## Dataset: load the configs and draw one training batch

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import os
from omegaconf import OmegaConf

from dataloader.dataset import CLIP_COCO_dataset
from dataloader.data_loaders import get_dataloader

from model.model import CLIP
from utils.simple_tokenizer import SimpleTokenizer
from utils.custom_schedulers import get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from utils import set_seed, mkdir, setup_logger, load_config_file

from torch.optim import Adam, AdamW # both are same but AdamW has a default weight decay

import argparse

DATA_CONFIG_PATH = 'dataloader/data_config.yaml'
TRAINER_CONFIG_PATH = 'trainer/train_config.yaml'
MODEL_CONFIG_PATH = 'model/model_config.yaml'

# Seed torch's global RNG up front so the randomly initialised model built in the
# next cell (and, assuming get_dataloader shuffles when is_train=True -- TODO confirm,
# the batch drawn below) is the same after every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Trainer and data settings are merged into one config for the dataset/loader;
# model_config is kept separate and read directly when the model is built.
data_config = load_config_file(DATA_CONFIG_PATH)
train_config = load_config_file(TRAINER_CONFIG_PATH)
model_config = load_config_file(MODEL_CONFIG_PATH)
config = OmegaConf.merge(train_config, data_config)

tokenizer = SimpleTokenizer()
train_dataset = CLIP_COCO_dataset(config, tokenizer)
train_dataloader = get_dataloader(config, train_dataset, is_train=True)

# Pull a single batch for inspection. The next cell unpacks `a` as
# (input_images, input_texts), so the name is kept as-is.
data_iter = iter(train_dataloader)
a = next(data_iter)

'\nfor _ in range(1849):\n    a = next(data_iter)\n    if(_ % 100 == 0):\n        print(_)\nprint(a[0].shape)\n'

## Forward pass on a single image-text batch

In [3]:
# Unpack the single batch fetched in the previous cell.
batch = a
input_images, input_texts = batch

# Build the RN50 variant of CLIP from the model config. `vision_layers` is
# converted from the OmegaConf list node to a plain tuple, and
# `vision_patch_size` is cleared -- presumably because it is only used by the
# ViT branch of model.CLIP (NOTE(review): confirm).
model_params = dict(model_config.RN50)
model_params['vision_layers'] = tuple(model_params['vision_layers'])
model_params['vision_patch_size'] = None
model = CLIP(**model_params)
print(model_params)

# The model is left on CPU in its default mode: cell-level code below indexes
# CPU tensors with `input_texts`, so moving the batch to GPU would break it.
image_features, text_features = model(input_images, input_texts)

# Sanity check: one embedding per sample, both projected to embed_dim.
embed_dim = model_params['embed_dim']
assert image_features.shape == (input_images.shape[0], embed_dim), image_features.shape
assert text_features.shape == (input_texts.shape[0], embed_dim), text_features.shape

# One line per tensor; the previous single print() lost a comma, which
# silently concatenated '\n' with 'input_texts:' and garbled the layout.
print('input_images:', input_images.shape)
print('image_features:', image_features.shape)
print()
print('input_texts:', input_texts.shape)
print('text_features:', text_features.shape)





{'embed_dim': 1024, 'image_resolution': 224, 'vision_layers': (3, 4, 6, 3), 'vision_width': 64, 'vision_patch_size': None, 'context_length': 77, 'vocab_size': 49408, 'transformer_width': 512, 'transformer_heads': 8, 'transformer_layers': 6}
x.shape1: torch.Size([64, 77, 512])
y.shape: torch.Size([64, 512])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
x.shape2: torch.Size([64, 1024])
input_images: torch.Size([64, 3, 224, 224]) 
 image_features: torch.Size([64, 1024]) 
 
input_texts: torch.Size([64, 77]) 
 text_features: torch.Size([64, 1024])


In [18]:
# Each caption ends with the end-of-text token 49407, which is the largest id
# in the 49408-token vocabulary, so argmax over the token dimension returns
# the position of that token in every caption.
eot_positions = input_texts.argmax(dim=-1)
print(eot_positions)
print(input_texts[1])

# Reproduce the pooling index on a probe tensor shaped like the text
# transformer output (batch, context_length, transformer_width) whose values
# equal their flat index, so the gathered values show exactly which positions
# were picked. Presumably this mirrors the indexing inside model.CLIP's text
# encoder (cf. its x.shape1 / arange debug prints) -- TODO confirm.
# Shapes are derived from the batch instead of hard-coding 64 * 77 * 512, so
# this also works for a partial last batch or a different config.
batch_size, context_length = input_texts.shape
width = model_params['transformer_width']
probe = torch.arange(batch_size * context_length * width).reshape(batch_size, context_length, width)
test = probe[torch.arange(batch_size), eot_positions]
assert test.shape == (batch_size, width), test.shape
print(test[1])

tensor([10, 17, 11, 12, 14, 12, 11, 18, 10, 22, 11, 12, 10, 10, 13, 14, 11, 15,
        10, 13, 12, 12, 12, 10, 12, 14, 12, 10, 12, 11, 11, 13, 12, 11, 11, 10,
        10, 15, 18, 12, 12, 10, 13, 14, 16, 12, 16, 13, 14, 14, 11, 11, 14, 14,
        14, 10, 22, 11, 10,  9, 12, 14, 11, 15])
tensor([49406,   320,  4038,   539,  2254, 22611,  2374,   530,   518, 11795,
          530,   911, 11747,   536,   320,  7147,   269, 49407,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0])
tensor([48128, 48129, 48130, 48131, 48132, 48133, 48134, 48135, 48136, 48137,
        48138, 48139, 48140, 