In [1]:
import json
from pathlib import Path
import random

import matplotlib
import matplotlib.pyplot as plt

import os

import torch

from data import CLEVRTextSplit, CLEVRMultimodalSplit, Scene

In [2]:
from config import load_config

config = load_config()

# Notebook-specific overrides applied on top of the loaded config, in order.
# Optional extra override (disabled here): 'shuffle_object_identities': True
config_overrides = {
    'display_object_properties': False,
    'batch_size': 256,
}
for option_name, option_value in config_overrides.items():
    setattr(config, option_name, option_value)

In [3]:
# Show which vocabulary file the loaded config points to.
config.vocabulary_path

'/workspace1/fidelrio/CLEVR_CoGenT_v1.0/vocab.txt'

In [4]:
# Build the train / test / systematic CLEVR multimodal splits from the config.
splits = CLEVRMultimodalSplit.build_splits(config)
train_dataset, test_dataset, systematic_dataset = splits

In [5]:
# Number of objects in each scene of the (unfiltered) train split.
object_lens = [
    len(train_dataset.scenes[idx]['objects'])
    for idx in range(len(train_dataset))
]

In [6]:
max_epochs = 1000

# Curriculum stages: stage k trains on scenes with 3..(3 + k) objects, k = 0..6.
# The final (0,) stage gives the largest train loader in the batch-count cell
# below, presumably meaning "no object-count filter" — confirm in CurriculumData.
schedule = [tuple(range(3, upper + 1)) for upper in range(3, 10)]
schedule.append((0,))

n_segments = len(schedule)  # one epoch segment per curriculum stage

In [7]:
# Inspect the curriculum stages (object-count tuples, one per stage).
schedule

[(3,),
 (3, 4),
 (3, 4, 5),
 (3, 4, 5, 6),
 (3, 4, 5, 6, 7),
 (3, 4, 5, 6, 7, 8),
 (3, 4, 5, 6, 7, 8, 9),
 (0,)]

In [8]:
# Nominal epochs per curriculum segment (floored; any remainder is not included).
segment_len = divmod(max_epochs, n_segments)[0]

In [9]:
# Number of stages and nominal epochs per stage.
(n_segments, segment_len)

(8, 125)

In [10]:
# Epoch [start, stop) range for each curriculum stage: exactly n_segments
# contiguous intervals covering [0, max_epochs). Mirrors np.array_split: the
# first `max_epochs % n_segments` stages get one extra epoch. So a max_epochs
# that is not a multiple of n_segments neither drops epochs nor creates a
# surplus interval. Stepping by the floored segment_len would create a surplus
# interval, e.g. 1000 epochs / 7 stages -> 8 intervals, the last ending at 1136.
base_len, n_longer = divmod(max_epochs, n_segments)
stage_starts = [k * base_len + min(k, n_longer) for k in range(n_segments + 1)]
intervals = list(zip(stage_starts[:-1], stage_starts[1:]))

In [11]:
# Epoch [start, stop) ranges for the curriculum stages; there should be n_segments of them.
intervals

[(0, 125),
 (125, 250),
 (250, 375),
 (375, 500),
 (500, 625),
 (625, 750),
 (750, 875),
 (875, 1000)]

In [12]:
import numpy as np

# Cross-check: numpy's array_split of the epoch range gives (start, stop) pairs
# to compare against `intervals`.
epoch_chunks = np.array_split(range(max_epochs), n_segments)
[(chunk[0], chunk[-1] + 1) for chunk in epoch_chunks]

[(0, 125),
 (125, 250),
 (250, 375),
 (375, 500),
 (500, 625),
 (625, 750),
 (750, 875),
 (875, 1000)]

In [13]:
current_epoch = 999

# Index of the curriculum stage whose [start, stop) interval contains
# current_epoch (first match, as before). Raise a descriptive error instead of
# a bare `IndexError: list index out of range` when the epoch lies outside every
# interval, e.g. current_epoch >= max_epochs.
current_stage = next(
    (stage for stage, (start, stop) in enumerate(intervals) if start <= current_epoch < stop),
    None,
)
if current_stage is None:
    raise ValueError(f'epoch {current_epoch} is outside all curriculum intervals {intervals}')

In [14]:
# Index into `intervals` of the stage containing current_epoch.
current_stage

7

In [15]:
from data import CurriculumData

In [16]:
# Enable multimodal pretraining. This mutates the same `config` object used by
# earlier cells, so re-running those cells afterwards sees the new value.
config.multimodal_pretraining = True

# Curriculum datamodule; setup('fit') presumably prepares the training splits
# (Lightning-style stage name) — confirm in data.CurriculumData.
data = CurriculumData(config)
data.setup('fit')

In [17]:
# Train on every object count from 3 through 10 (inclusive).
# An alternative setting is the tuple (0,), also used as the final curriculum stage.
data.train_with_n_objects = tuple(range(3, 10 + 1))

In [18]:
%%time
ds = data.train_dataloader().dataset

CPU times: user 6.54 s, sys: 633 ms, total: 7.18 s
Wall time: 7.19 s


In [19]:
# Number of objects in each scene of the curriculum-filtered train dataset.
object_lens = [len(ds.scenes[idx]['objects']) for idx in range(len(ds))]

In [20]:
from collections import Counter

# Scenes per object count, most frequent first.
object_len_counts = Counter(object_lens)
object_len_counts.most_common()

[(5, 8826),
 (6, 8813),
 (10, 8775),
 (8, 8766),
 (4, 8754),
 (9, 8702),
 (3, 8686),
 (7, 8678)]

In [23]:
# Train-loader batch count for every curriculum stage. The loop mutates
# `data.train_with_n_objects` on the shared datamodule. Restore the previous
# value afterwards; otherwise later cells (e.g. the `ds = ...` cell below)
# silently see the last stage `(0,)` instead of the value set earlier.
previous_n_objects = data.train_with_n_objects
try:
    for n_objects in schedule:
        data.train_with_n_objects = n_objects
        print(n_objects, len(data.train_dataloader()))
finally:
    data.train_with_n_objects = previous_n_objects

(3,) 34
(3, 4) 69
(3, 4, 5) 103
(3, 4, 5, 6) 138
(3, 4, 5, 6, 7) 171
(3, 4, 5, 6, 7, 8) 206
(3, 4, 5, 6, 7, 8, 9) 240
(0,) 274


In [22]:
# Rebuild the train loader with the current object-count filter and keep its dataset.
train_loader = data.train_dataloader()
ds = train_loader.dataset