In [20]:
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
#from torch_geometric.data import Data
#from torch_geometric.loader import DataLoader

from data.ag.action_genome import AG

from torch import Tensor
import torch.nn.functional as F
import torchvision.transforms as T


from models.rgcn import RGCN
from torchvision.models import vit_b_16, ViT_B_16_Weights, VisionTransformer

%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
root = '/data/Datasets/ag/'
train_set = AG(root, split='train', subset_file='data/ag/subset_shelve')
test_set = AG(root, split='test', subset_file='data/ag/subset_shelve')

split: train length: 6388
split: test length: 1597


In [22]:
train_loader = DataLoader(train_set, batch_size=16, collate_fn=train_set.verb_pred_collate)
test_loader = DataLoader(test_set, batch_size=1, collate_fn=test_set.verb_pred_collate)

In [31]:
%autoreload

class JointModel(nn.Module):
    def __init__(self, vit, rgcn):
        super(JointModel, self).__init__()
        self.vit = vit
        self.rgcn = rgcn
    
    def forward(self, img, sg):
        img = self.vit(img)
        sg = self.rgcn(sg)
        return img, sg

class ViT(nn.Module):
    def __init__(self, num_classes):
        super(ViT, self).__init__()
        vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        vit.heads = torch.nn.Identity()

        #freeze the backbone
        for param in vit.parameters():
            param.requires_grad = False

        #set the head to use our num classes
        vit.heads = torch.nn.Linear(vit.hidden_dim, num_classes)
        self.vit = vit

    def forward(self, x):
        return self.vit(x)

def train(model, loader, weight, device, epochs=1, lr=1e-2):
    criterion = torch.nn.CrossEntropyLoss(weight=weight)

    model = model.to(device)
    P = model.parameters()

    optimizer = torch.optim.Adam(P, lr=lr)

    for e in range(epochs):
        epoch_loss = 0
        for batch in tqdm(loader):
            ids, imgs, sgs, verbs, labels = batch

            labels = labels.to(device)
            sgs = sgs.to(device)
            imgs = imgs.to(device)

            optimizer.zero_grad()
            if isinstance(model, RGCN):
                out = model(sgs)
            elif isinstance(model, ViT):
                out = model(imgs)
            elif isinstance(model, JointModel):
                out = model(imgs, sgs)
            else:
                raise ValueError(f'Unknown model type: {model.__class__}')

            loss = criterion(out, labels)

            optimizer.step()
            epoch_loss += loss.item()
        print(f'Epoch {e} loss: {epoch_loss/len(train_set)}')

In [32]:
%autoreload
epochs = 10
device = torch.device('cuda:0')

node_feature_size = 32
num_obj_classes = len(train_set.object_classes)
num_verb_classes = len(train_set.verb_classes)
num_rel_classes = len(train_set.relationship_classes)
print(num_obj_classes, num_verb_classes, num_rel_classes)

rgcn = RGCN(num_obj_classes, node_feature_size, num_verb_classes, num_rel_classes)
vit = ViT(num_verb_classes)
joint_model = JointModel(vit, rgcn)

weight = len(train_set) / (num_verb_classes * train_set.verb_label_counts)
weight = torch.tensor(weight, dtype=torch.float).to(device)

train(vit, train_loader, weight, device, epochs=epochs)

36 33 26


100%|██████████| 759/759 [01:50<00:00,  6.85it/s]


Epoch 0 loss: 0.22319689100546064


100%|██████████| 759/759 [01:52<00:00,  6.73it/s]


Epoch 1 loss: 0.22319689100546064


100%|██████████| 759/759 [01:50<00:00,  6.86it/s]


Epoch 2 loss: 0.22319689100546064


  8%|▊         | 63/759 [00:09<01:42,  6.76it/s]


KeyboardInterrupt: 

In [14]:
import random

arr = []
for i in range(10000):
    arr.append(random.randint(0, 9))

torch.bincount(torch.Tensor(arr).int())

        '''
        super(ViT, self).__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        self.out_proj = nn.Linear(1000, num_classes)
        '''


tensor([ 986, 1003, 1042, 1029, 1030, 1004,  999,  938,  993,  976])

In [17]:
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Load pre-trained ViT model
model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

model.heads = torch.nn.Identity()
print(model.heads)
print(model.hidden_dim)

# Prepare input
img = torch.randn(1, 3, 224, 224)

# Forward pass
output = model(img)
print(output.shape)


Identity()
768
torch.Size([1, 768])
