In [1]:
import os

import comet_ml

import torch
from torch.utils.data import DataLoader

from config import Config, load_config
from data import CLEVRSplit, CLEVRTextSplit
from model import Model, TextualModel, TrainingModel

from tqdm.auto import tqdm 

import lightning as L
from lightning import Trainer
from lightning.pytorch.loggers.comet import CometLogger
from lightning.pytorch.callbacks import ModelCheckpoint


def log_to_comet():
    return False

In [2]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = load_config()
config.use_txt_scene = True

if config.use_txt_scene:
    train_dataset, test_dataset, systematic_dataset = CLEVRTextSplit.build_splits(config)
else:
    train_dataset, test_dataset, systematic_dataset = CLEVRSplit.build_splits(config)
    
config.pad_idx = train_dataset.pad_idx

experiment_name = "notebook"

Building vocabulary


  0%|          | 0/699960 [00:00<?, ?it/s]

Building answers index


  0%|          | 0/699960 [00:00<?, ?it/s]

In [55]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
#     print(batch[0])
    scenes, questions, answers = zip(*batch)
    inputs = [torch.cat(comb) for comb in zip(questions, scenes)]
    inputs = pad_sequence(inputs, batch_first=True, padding_value=config.pad_idx)
    answers = torch.tensor(answers)
    return inputs, answers

In [56]:
dlkwargs = {
    'batch_size': config.batch_size,
    'num_workers': int(os.environ.get("SLURM_JOB_CPUS_PER_NODE", 4)),
    'pin_memory': torch.cuda.is_available(),
}

train_loader = DataLoader(train_dataset, shuffle=True, collate_fn=collate_fn, **dlkwargs)
test_loader = DataLoader(test_dataset, shuffle=False, collate_fn=collate_fn, **dlkwargs)
systematic_loader = DataLoader(systematic_dataset, shuffle=False, collate_fn=collate_fn, **dlkwargs)

In [57]:
for ds in train_dataset, test_dataset, systematic_dataset:
    ds.processor.pad_questions = False
    ds.processor.pad_scenes = False

In [58]:
%%time
[_ for _ in tqdm(train_loader)]
None

  0%|          | 0/10937 [00:03<?, ?it/s]

CPU times: user 15.9 s, sys: 7.2 s, total: 23.1 s
Wall time: 3min 44s


In [68]:
batch = next(iter(train_loader))

In [69]:
batch

[tensor([[ 0, 86, 26,  ...,  5, 74, 11],
         [ 0, 93, 65,  ...,  1,  1,  1],
         [ 0, 39, 86,  ...,  1,  1,  1],
         ...,
         [ 0, 93, 77,  ...,  1,  1,  1],
         [ 0, 93, 65,  ...,  1,  1,  1],
         [ 0, 21, 87,  ...,  1,  1,  1]]),
 tensor([15,  1, 20,  0, 24, 27, 20, 20, 18,  1, 23, 15, 20, 15, 27,  0, 20,  1,
         19, 23,  1, 20, 27,  0, 27, 27, 27, 11,  1,  1, 23, 19, 17, 24, 20,  1,
         27, 20,  5, 26, 13, 27,  3,  3, 24, 13, 20,  1, 27, 20, 20, 27, 20, 20,
         20,  0, 20, 18, 21, 25, 23, 25,  1, 20])]

In [70]:
batch[0].shape, batch[1].shape

(torch.Size([64, 292]), torch.Size([64]))

In [71]:
(batch[0] == config.pad_idx).sum() / batch[0].numel()

tensor(0.3469)

In [72]:
(batch[1] == config.pad_idx).sum() / batch[0].numel()

tensor(0.0005)

In [73]:
# (torch.cat((batch[0], batch[1]), dim=1) == config.pad_idx).sum() / torch.cat((batch[0], batch[1]), dim=1).numel()

In [74]:
# torch.cat((batch[0], batch[1]), dim=1).numel()

In [75]:
batch[0].numel()

18688