In [40]:
from tqdm import tqdm
import tensorboard

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
#from torch_geometric.data import Data
#from torch_geometric.loader import DataLoader

from data.ag.action_genome import AG

from torch import Tensor
import torch.nn.functional as F
import torchvision.transforms as T

from models.rgcn import RGCN
from torchvision.models import vit_b_16, ViT_B_16_Weights, VisionTransformer

import pytorch_lightning as L 
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics


import warnings
warnings.filterwarnings("ignore")
#warnings.filterwarnings("default")

%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Both splits are read from the same dataset root and share one subset file;
# only the `split` argument differs. The train split is constructed first.
root = '/data/Datasets/ag/'
train_set, val_set = (
    AG(root, split=split_name, subset_file='data/ag/subset_shelve')
    for split_name in ('train', 'test')
)

split: train length: 6388
split: test length: 1597


In [41]:
# Training batches are reshuffled every epoch. The validation loader keeps a
# fixed order so per-step validation logs are comparable between epochs and
# runs (Lightning also warns when a validation loader is shuffled).
train_loader = DataLoader(
    train_set,
    batch_size=16,
    collate_fn=train_set.verb_pred_collate,
    num_workers=16,
    shuffle=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=128,
    collate_fn=val_set.verb_pred_collate,
    num_workers=16,
    shuffle=False,
)

In [42]:
%autoreload

class ViT(nn.Module):
    """Pretrained ViT-B/16 with frozen weights and a new trainable linear head.

    Args:
        num_classes: Output width of the replacement linear head.
        head: If True, ``forward`` applies a softmax over the last dimension
            and therefore returns probabilities, not logits. If False, the
            raw output of the linear head is returned.
    """

    def __init__(self, num_classes, head=True):
        super().__init__()
        self.head = head

        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

        # Freeze every pretrained parameter. The stock classification head is
        # frozen too, but it is discarded on the next statement.
        backbone.requires_grad_(False)

        # The replacement head is created after freezing, so it is the only
        # part of the network left trainable.
        backbone.heads = torch.nn.Linear(backbone.hidden_dim, num_classes)
        self.vit = backbone

    def forward(self, x):
        """Run the backbone and head; softmax the result when ``self.head`` is set."""
        out = self.vit(x)
        return F.softmax(out, dim=-1) if self.head else out

class JointModel(nn.Module):
    """Classifier over concatenated image (ViT) and scene-graph (RGCN) features.

    Args:
        rgcn_params: 4-tuple ``(num_obj_classes, node_feature_size,
            rgcn_hidden_dim, num_rel_classes)`` forwarded to ``RGCN``.
        vit_hidden_dim: Output width of the ViT branch.
        num_classes: Number of output classes of the joint linear head.
    """

    def __init__(self, rgcn_params, vit_hidden_dim, num_classes):
        super().__init__()
        n_obj, feat_size, rgcn_hidden_dim, n_rel = rgcn_params
        # Both branches are built with head=False so they act as feature
        # extractors for the joint head below.
        self.rgcn = RGCN(n_obj, feat_size, rgcn_hidden_dim, n_rel, head=False)
        self.vit = ViT(vit_hidden_dim, head=False)
        self.head = nn.Linear(vit_hidden_dim + rgcn_hidden_dim, num_classes)

    def forward(self, img, sg):
        """Return a softmax over classes computed from both branches.

        The image features come first in the concatenation (along dim 1),
        followed by the scene-graph features.
        """
        img_feat = self.vit(img)
        sg_feat = self.rgcn(sg)
        fused = torch.cat([img_feat, sg_feat], dim=1)
        scores = self.head(fused)
        return F.softmax(scores, dim=-1)


# trying pytorch lightning

In [46]:
# trying pytorch lightning

class JointModelLightning(L.LightningModule):
    """Lightning wrapper that trains one of the verb-prediction models.

    Args:
        model_params: 3-tuple ``(rgcn_params, vit_hidden_dim, num_classes)``
            where ``rgcn_params`` is itself the 4-tuple
            ``(num_obj_classes, node_feature_size, rgcn_hidden_dim,
            num_rel_classes)``.
        weight: Per-class weights passed to ``F.cross_entropy`` (or ``None``
            for an unweighted loss).
        model_type: ``'joint'`` (ViT + RGCN), ``'rgcn'`` or ``'vit'``.
        lr: Learning rate of the Adam optimizer.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.

    NOTE(review): ``ViT(head=True)`` and ``JointModel`` in this notebook
    return softmax output, while ``F.cross_entropy`` applies its own
    log-softmax to its input. For the 'vit' and 'joint' modes the loss is
    therefore computed on doubly-normalised scores. Whether ``RGCN`` behaves
    the same way cannot be seen from here -- confirm, then make the models
    return logits.
    """

    def __init__(self, model_params, weight, model_type='joint', lr=1e-2):
        super().__init__()
        self.model_type = model_type
        self.lr = lr
        rgcn_params, vit_hidden_dim, num_classes = model_params
        if model_type == 'joint':
            self.model = JointModel(rgcn_params, vit_hidden_dim, num_classes)
        elif model_type == 'rgcn':
            # The stand-alone RGCN receives num_classes in the position where
            # the joint model passes rgcn_hidden_dim.
            num_obj_classes, node_feature_size, _, num_rel_classes = rgcn_params
            self.model = RGCN(num_obj_classes, node_feature_size, num_classes, num_rel_classes, head=True)
        elif model_type == 'vit':
            self.model = ViT(num_classes, head=True)
        else:
            # Fail here instead of with an AttributeError on the first forward.
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'joint', 'rgcn' or 'vit'"
            )

        # A buffer follows the module across devices, so the weights no longer
        # have to be placed on the right device by the caller. It is
        # non-persistent so the checkpoint's state_dict keys stay unchanged.
        # A separate local name is used so save_hyperparameters() below still
        # records the `weight` argument exactly as it was passed in.
        class_weight = None if weight is None else torch.as_tensor(weight, dtype=torch.float)
        self.register_buffer('weight', class_weight, persistent=False)

        # Epoch-level metrics, one instance per stage so their states never mix.
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

        self.save_hyperparameters()

    def forward(self, img, sg):
        """Dispatch the inputs that the selected model type consumes."""
        if self.model_type == 'rgcn':
            return self.model(sg)
        elif self.model_type == 'vit':
            return self.model(img)
        else:
            return self.model(img, sg)

    def _shared_step(self, batch, metric):
        """Run the model on one batch, update ``metric``, return (loss, batch size)."""
        ids, imgs, sgs, verbs, labels = batch
        out = self(imgs, sgs)
        loss = F.cross_entropy(out, labels, weight=self.weight)
        # Calling the metric accumulates this batch into its epoch state.
        metric(out, labels)
        return loss, labels.shape[0]

    def training_step(self, batch, batch_idx):
        """Compute and log the weighted training loss and accuracy."""
        loss, batch_size = self._shared_step(batch, self.train_accuracy)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        # Logging the metric object lets Lightning call compute() at epoch end
        # and reset the state afterwards.
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the weighted validation loss and accuracy."""
        val_loss, batch_size = self._shared_step(batch, self.val_accuracy)

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return val_loss

    def configure_optimizers(self):
        """Adam over all module parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [47]:
%autoreload
# --- Run configuration -------------------------------------------------------
epochs = 10
device = torch.device('cuda:0')

# Vocabulary sizes are taken from the training split.
node_feature_size = 32
num_obj_classes = len(train_set.object_classes)
num_verb_classes = len(train_set.verb_classes)
num_rel_classes = len(train_set.relationship_classes)
print(num_obj_classes, num_verb_classes, num_rel_classes)

# Parameter tuples in the layout the model constructors unpack.
rgcn_hidden_dim = 32
vit_hidden_dim = 32
rgcn_params = (
    num_obj_classes,
    node_feature_size,
    rgcn_hidden_dim,
    num_rel_classes,
)
model_params = (rgcn_params, vit_hidden_dim, num_verb_classes)

# Class weights: n_samples / (n_classes * per-class count).
# Assumes train_set.verb_label_counts supports elementwise arithmetic
# (e.g. a NumPy array) -- TODO confirm.
weight = torch.tensor(
    len(train_set) / (num_verb_classes * train_set.verb_label_counts),
    dtype=torch.float,
).to(device)
print(weight)

36 33 26
tensor([ 1.9013,  0.5761,  4.8325,  3.9094,  1.1152,  0.8528,  3.0791, 12.8866,
         0.5155,  1.1297,  5.6119, 57.9899,  0.4525,  5.3529,  1.6890,  4.3492,
         0.1672,  2.0835,  0.4018,  0.5585,  0.9829,  2.7397,  0.4349,  0.7801,
         8.2843,  0.7872,  0.8507, 10.2335,  2.8995,  1.3331,  2.2162,  0.7419,
         3.3137], device='cuda:0')


In [48]:
# Single switch for which model is trained. It drives the model construction,
# the checkpoint file names and the logger name so the three cannot disagree.
model_type = 'rgcn'

# Initialize model and trainer
lightning_model = JointModelLightning(model_params, weight, model_type=model_type)

# Setup callbacks and logger
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    # The {epoch}/{val_loss} placeholders are filled in by ModelCheckpoint, so
    # they are kept out of an f-string; only the model-type prefix is built here.
    filename=model_type + '-model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

logger = TensorBoardLogger("lightning_logs", name=f"{model_type}_model")

trainer = L.Trainer(
    # Reuse the `epochs` setting from the configuration cell instead of
    # repeating the literal here.
    max_epochs=epochs,
    accelerator='gpu',
    devices=[0],
    callbacks=[checkpoint_callback],
    logger=logger,
)

# Train the model
trainer.fit(lightning_model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | model          | RGCN               | 30.0 K | train
1 | train_accuracy | MulticlassAccuracy | 0      | train
2 | val_accuracy   | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
30.0 K    Trainable params
0         Non-trainable params
30.0 K    Total params
0.120     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 9: 100%|██████████| 718/718 [00:42<00:00, 16.77it/s, v_num=11, train_loss_step=3.480, val_loss_step=7.230, val_loss_epoch=3.670, val_acc=0.928, train_loss_epoch=3.450, train_acc=0.913]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 718/718 [00:42<00:00, 16.77it/s, v_num=11, train_loss_step=3.480, val_loss_step=7.230, val_loss_epoch=3.670, val_acc=0.928, train_loss_epoch=3.450, train_acc=0.913]


In [15]:
import torch_geometric.data.collate