In [1]:
from CUBDatasets import CUBImageFt, CUBSentences
import torch

In [2]:
# If the code below raises a LookupError about punkt, uncomment and run the following:
# import nltk
# nltk.download('punkt')

In [3]:
# Data directory handed to the CUB dataset constructors.
RAWDATA_PATH = "data"

In [4]:
def tx(data):
    """Convert ``data`` into a tensor via ``torch.Tensor(data)``.

    Written as a ``def`` instead of a lambda bound to a name (PEP 8), so the
    callable carries the name ``tx`` in tracebacks and reprs.

    Args:
        data: Anything accepted by the ``torch.Tensor`` constructor.

    Returns:
        The tensor built by ``torch.Tensor(data)``.
    """
    return torch.Tensor(data)

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
# Image-feature datasets for the train and test splits.
img_train = CUBImageFt(
    RAWDATA_PATH,
    split='train',
    device=device,
)
img_test = CUBImageFt(
    RAWDATA_PATH,
    split='test',
    device=device,
)

In [7]:
# Number of entries in the train / test image-feature datasets.
tuple(map(len, (img_train, img_test)))

(8855, 2933)

In [8]:
# If the code below raises a LookupError about punkt, uncomment and run the following:
# import nltk
# nltk.download('punkt')

In [9]:
# Value passed to CUBSentences as max_sequence_length.
maxSentLen = 32

# Each entry of these datasets is a 2-tuple: (padded sentence embedding, actual length).
txt_train = CUBSentences(
    RAWDATA_PATH,
    split='train',
    transform=tx,
    max_sequence_length=maxSentLen,
)
txt_test = CUBSentences(
    RAWDATA_PATH,
    split='test',
    transform=tx,
    max_sequence_length=maxSentLen,
)

In [10]:
# There are ten sentences per image, so these are 10x the image dataset sizes.
tuple(map(len, (txt_train, txt_test)))

(88550, 29330)

## Test

In [10]:
# Shuffled loader over the training image features; the last line shows the batch count.
img_training_generator = torch.utils.data.DataLoader(
    dataset=img_train,
    batch_size=200,
    shuffle=True,
)
len(img_training_generator)

45

In [11]:
# for i, img in enumerate(img_training_generator):
#     print(i, img.shape)

In [12]:
# Shuffled loader over the test image features; the last line shows the batch count.
img_testing_generator = torch.utils.data.DataLoader(
    dataset=img_test,
    batch_size=200,
    shuffle=True,
)
len(img_testing_generator)

15

In [13]:
# for i, img in enumerate(img_testing_generator):
#     print(i, img.shape)

In [14]:
# Shuffled loader over the training sentences; the last line shows the batch count.
txt_training_generator = torch.utils.data.DataLoader(
    dataset=txt_train,
    batch_size=2000,
    shuffle=True,
)
len(txt_training_generator)

45

In [15]:
# for i, txt in enumerate(txt_training_generator):
#     print(i, len(txt), txt[0].shape, txt[1].shape)

In [16]:
# Shuffled loader over the test sentences; the last line shows the batch count.
txt_testing_generator = torch.utils.data.DataLoader(
    dataset=txt_test,
    batch_size=2000,
    shuffle=True,
)
len(txt_testing_generator)

15

In [17]:
# for i, txt in enumerate(txt_testing_generator):
#     print(i, len(txt), txt[0].shape, txt[1].shape)

In [18]:
# EOS = 2
# for i, l in enumerate(txt[1]):
#     try:
#         print(torch.where(txt[0][i] == EOS)[0][0] + 1 == l)
#     except:
#         print(txt[0][i])

## Joint Dataloader

In [19]:
class CUB(torch.utils.data.Dataset):
    """Joint CUB dataset pairing every sentence with an image feature.

    The dataset is indexed like the sentence dataset: item ``idx`` is sentence
    ``idx`` together with image ``idx // SENTENCES_PER_IMAGE``, i.e. each run
    of ``SENTENCES_PER_IMAGE`` consecutive sentences shares one image entry.
    """

    # Number of consecutive sentence entries that map to a single image entry.
    SENTENCES_PER_IMAGE = 10

    def __init__(self, img_data_dir, txt_data_dir, split, device, transform=None, **kwargs):
        """Build the underlying sentence and image-feature datasets.

        Args:
            img_data_dir: Data directory forwarded to ``CUBImageFt``.
            txt_data_dir: Data directory forwarded to ``CUBSentences``.
            split: 'train' or 'test'.
            device: Forwarded to ``CUBImageFt`` as ``device``.
            transform: Forwarded to ``CUBSentences`` as ``transform``.
            **kwargs: Extra keyword arguments forwarded to ``CUBSentences``.
        """
        super().__init__()
        self.CUBtxt = CUBSentences(txt_data_dir, split=split, transform=transform, **kwargs)
        self.CUBimg = CUBImageFt(img_data_dir, split=split, device=device)

    def __len__(self):
        """Return the number of sentence entries (one joint item per sentence)."""
        return len(self.CUBtxt)

    def __getitem__(self, idx):
        """Return ``(img, txt)`` for sentence ``idx`` and the image it belongs to."""
        txt = self.CUBtxt[idx]
        # Consecutive sentences share an image, so integer division picks the image.
        img = self.CUBimg[idx // self.SENTENCES_PER_IMAGE]
        return img, txt

In [13]:
from joint_dataset import CUB

In [14]:
# Joint dataset for the 'train' split. RAWDATA_PATH is passed twice: presumably as
# (img_data_dir, txt_data_dir), matching the CUB class defined in this notebook.
# NOTE(review): CUB here is the one imported from joint_dataset; confirm its signature.
CUB_train = CUB(RAWDATA_PATH, RAWDATA_PATH, 'train', device, tx)

In [15]:
# Unshuffled loader over the joint training set; the last line shows the batch count.
training_generator = torch.utils.data.DataLoader(
    dataset=CUB_train,
    batch_size=2000,
    shuffle=False,
)
len(training_generator)

45

In [16]:
# Sanity check: per batch, print the shape of the de-duplicated image rows
# next to the shape of the first element of the text tuple.
for i, (img, txt) in enumerate(training_generator):
    print(i, img.unique(dim=0).shape, txt[0].shape)

0 torch.Size([200, 2048]) torch.Size([2000, 32])
1 torch.Size([200, 2048]) torch.Size([2000, 32])
2 torch.Size([200, 2048]) torch.Size([2000, 32])
3 torch.Size([200, 2048]) torch.Size([2000, 32])
4 torch.Size([200, 2048]) torch.Size([2000, 32])
5 torch.Size([200, 2048]) torch.Size([2000, 32])
6 torch.Size([200, 2048]) torch.Size([2000, 32])
7 torch.Size([200, 2048]) torch.Size([2000, 32])
8 torch.Size([200, 2048]) torch.Size([2000, 32])
9 torch.Size([200, 2048]) torch.Size([2000, 32])
10 torch.Size([200, 2048]) torch.Size([2000, 32])
11 torch.Size([200, 2048]) torch.Size([2000, 32])
12 torch.Size([200, 2048]) torch.Size([2000, 32])
13 torch.Size([200, 2048]) torch.Size([2000, 32])
14 torch.Size([200, 2048]) torch.Size([2000, 32])
15 torch.Size([200, 2048]) torch.Size([2000, 32])
16 torch.Size([200, 2048]) torch.Size([2000, 32])
17 torch.Size([200, 2048]) torch.Size([2000, 32])
18 torch.Size([200, 2048]) torch.Size([2000, 32])
19 torch.Size([200, 2048]) torch.Size([2000, 32])
20 torch.S

In [17]:
# Joint dataset for the 'test' split. RAWDATA_PATH is passed twice: presumably as
# (img_data_dir, txt_data_dir), matching the CUB class defined in this notebook.
# NOTE(review): CUB here is the one imported from joint_dataset; confirm its signature.
CUB_test = CUB(RAWDATA_PATH, RAWDATA_PATH, 'test', device, tx)

In [18]:
# Unshuffled loader over the joint test set; the last line shows the batch count.
testing_generator = torch.utils.data.DataLoader(
    dataset=CUB_test,
    batch_size=2000,
    shuffle=False,
)
len(testing_generator)

15

In [19]:
# Sanity check: per batch, print the shape of the de-duplicated image rows
# next to the shape of the first element of the text tuple.
for i, (img, txt) in enumerate(testing_generator):
    print(i, img.unique(dim=0).shape, txt[0].shape)

0 torch.Size([200, 2048]) torch.Size([2000, 32])
1 torch.Size([200, 2048]) torch.Size([2000, 32])
2 torch.Size([200, 2048]) torch.Size([2000, 32])
3 torch.Size([200, 2048]) torch.Size([2000, 32])
4 torch.Size([200, 2048]) torch.Size([2000, 32])
5 torch.Size([200, 2048]) torch.Size([2000, 32])
6 torch.Size([200, 2048]) torch.Size([2000, 32])
7 torch.Size([200, 2048]) torch.Size([2000, 32])
8 torch.Size([200, 2048]) torch.Size([2000, 32])
9 torch.Size([200, 2048]) torch.Size([2000, 32])
10 torch.Size([200, 2048]) torch.Size([2000, 32])
11 torch.Size([200, 2048]) torch.Size([2000, 32])
12 torch.Size([200, 2048]) torch.Size([2000, 32])
13 torch.Size([200, 2048]) torch.Size([2000, 32])
14 torch.Size([133, 2048]) torch.Size([1330, 32])
