In [1]:
# Reload imported project modules (config, data, model) automatically before each
# cell executes, so edits to those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import comet_ml

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from config import Config, load_config
from data import CLEVRSplit, CLEVRTextSplit, CLEVRMultimodalSplit, CollatorForMaskedLanguageModeling
from model import Model, TextualModel, TrainingModel, MultimodalModel


import lightning as L
from lightning import Trainer
from lightning.pytorch.loggers.comet import CometLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [3]:
def log_to_comet() -> bool:
    """Return whether experiment logging to Comet is enabled.

    Currently hard-wired to ``False``.
    NOTE(review): nothing in this section calls it; presumably it is meant to gate
    the CometLogger imported above — confirm against the training code.
    """
    comet_enabled = False
    return comet_enabled

In [4]:
# Run-wide settings: seeding, cuDNN autotuning, device selection and the run config.
# Fixed seed so shuffling, weight init and (presumably torch-based) MLM masking are
# reproducible. NOTE(review): if Config defines its own seed, use that here instead.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# Let cuDNN benchmark and pick the fastest kernels; note this may make GPU results
# non-bit-exact between runs even with a fixed seed.
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = load_config()
# Force multimodal pretraining regardless of what the loaded config says.
config.multimodal_pretraining = True

In [5]:
# Build the train / test / systematic splits.
# Branch order matches the later MultimodalModel / TextualModel / Model selection
# (multimodal first). Otherwise use_txt_scene=True in the loaded config would pair a
# text-scene dataset with a MultimodalModel.
if config.multimodal_pretraining:
    split_cls = CLEVRMultimodalSplit
elif config.use_txt_scene:
    split_cls = CLEVRTextSplit
else:
    split_cls = CLEVRSplit
train_dataset, test_dataset, systematic_dataset = split_cls.build_splits(config)

# Copy the dataset's padding index into config before any model is built from config
# (presumably the models read it from there — confirm in model.py).
config.pad_idx = train_dataset.pad_idx

In [6]:
def _slurm_cpus_per_node(default: int = 4) -> int:
    """Return the CPU count SLURM granted on the first node, or ``default`` outside SLURM.

    SLURM_JOB_CPUS_PER_NODE is either a plain count ("8") or a compressed multi-node
    list such as "72(x2),36"; a bare int() on the latter raises ValueError.
    """
    raw = os.environ.get("SLURM_JOB_CPUS_PER_NODE", "").strip()
    if not raw:
        return default
    # Keep only the first entry and drop any "(xN)" repetition suffix.
    return int(raw.split(",")[0].split("(")[0])


MLM_PROBABILITY = 0.15  # fraction of tokens selected for masking (the BERT default)

dlkwargs = {
    'batch_size': config.batch_size,
    'num_workers': _slurm_cpus_per_node(),
    'pin_memory': torch.cuda.is_available(),
}

collator = CollatorForMaskedLanguageModeling(config, train_dataset.processor, mlm_probability=MLM_PROBABILITY)
train_loader = DataLoader(train_dataset, shuffle=True, collate_fn=collator, **dlkwargs)
test_loader = DataLoader(test_dataset, shuffle=False, collate_fn=collator, **dlkwargs)
systematic_loader = DataLoader(systematic_dataset, shuffle=False, collate_fn=collator, **dlkwargs)

# Pick the model that matches the dataset that was actually built, so training_model
# never wraps e.g. a plain Model around multimodal batches.
# NOTE(review): assumes build_splits returns instances of the class it is called on.
if isinstance(train_dataset, CLEVRMultimodalSplit):
    model = MultimodalModel(config)
elif isinstance(train_dataset, CLEVRTextSplit):
    model = TextualModel(config)
else:
    model = Model(config)
# NOTE(review): this cell emits Lightning's "Attribute 'model' is an instance of
# `nn.Module` and is already saved during checkpointing" warning, presumably from
# TrainingModel calling self.save_hyperparameters(); ignore=["model"] there would silence it.
training_model = TrainingModel(model, config)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


In [7]:
# Build a fresh model for interactive inspection and place it on `device`.
# Multimodal pretraining takes precedence over text scenes, which take precedence
# over the default image model.
# NOTE(review): this is a new instance, separate from the one already wrapped by
# `training_model` in the previous cell.
if config.multimodal_pretraining:
    model_cls = MultimodalModel
elif config.use_txt_scene:
    model_cls = TextualModel
else:
    model_cls = Model

model = model_cls(config).to(device)

In [8]:
# Pull one training batch and move each of its tensors onto `device`
# (unpacked below as images, scenes, labels).
batch = [
    tensor.to(device)
    for tensor in next(iter(train_loader))
]

In [11]:
# Run one forward pass on the sample batch so this cell works on a fresh kernel
# (it previously relied on `output`/`labels` left over from a deleted cell).
images, scenes, labels = batch
with torch.no_grad():  # shapes only; the loss cell below recomputes with gradients
    output = model(images, scenes)
output.shape, labels.shape

(torch.Size([64, 359, 96]), torch.Size([64, 359]))

In [12]:
# Forward one batch and report the MLM loss and the accuracy on scored (masked) tokens.
IGNORE_INDEX = -100  # label value for positions excluded from the loss (F.cross_entropy's default)

images, scenes, labels = batch
output = model(images, scenes)  # (batch, seq, vocab), labels (batch, seq), as the shape cell shows
# cross_entropy wants the class dimension second: (batch, vocab, seq) against labels (batch, seq).
loss = F.cross_entropy(output.transpose(1, 2), labels, ignore_index=IGNORE_INDEX)

with torch.no_grad():
    pred = output.argmax(dim=-1)     # highest-scoring token id at each position
    scored = labels != IGNORE_INDEX  # positions that carry a real target
    correct = (pred[scored] == labels[scored]).sum()
    count = scored.sum()
    # clamp avoids 0/0 = nan if a batch happens to contain no scored positions
    acc = correct.float() / count.clamp(min=1)

loss, acc

(tensor(4.4971, device='cuda:0', grad_fn=<NllLoss2DBackward0>),
 tensor(0.0213, device='cuda:0'))

In [15]:
output.size()  # same torch.Size as .shape: (batch, seq, vocab)

torch.Size([64, 359, 96])

In [14]:
labels.nelement()  # total label positions in the batch (alias of numel())

22976