In [1]:
from tqdm import tqdm
import tensorboard

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
#from torch_geometric.data import Data
#from torch_geometric.loader import DataLoader

from data.ag.action_genome import AG

from torch import Tensor
import torch.nn.functional as F
import torchvision.transforms as T

from models.rgcn import RGCN
from torchvision.models import vit_b_16, ViT_B_16_Weights, VisionTransformer

import pytorch_lightning as L 
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics

%load_ext autoreload

In [2]:
root = '/data/Datasets/ag/'
train_set = AG(root, split='train', subset_file='data/ag/subset_shelve')
test_set = AG(root, split='test', subset_file='data/ag/subset_shelve')

split: train length: 6388
split: test length: 1597


In [3]:
train_loader = DataLoader(train_set, batch_size=16, collate_fn=train_set.verb_pred_collate, num_workers=8)
test_loader = DataLoader(test_set, batch_size=1, collate_fn=test_set.verb_pred_collate)

In [4]:
%autoreload

class ViT(nn.Module):
    def __init__(self, num_classes):
        super(ViT, self).__init__()
        vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        vit.heads = torch.nn.Identity()

        #freeze the backbone
        for param in vit.parameters():
            param.requires_grad = False

        #set the head to use our num classes
        vit.heads = torch.nn.Linear(vit.hidden_dim, num_classes)
        self.vit = vit

    def forward(self, x):
        return self.vit(x)

class JointModel(nn.Module):
    def __init__(self, rgcn_params, vit_hidden_dim, num_classes):
        super(JointModel, self).__init__()
        num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes = rgcn_params
        self.rgcn = RGCN(num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes)
        self.vit = ViT(vit_hidden_dim)
        self.head = nn.Linear(vit_hidden_dim + rgcn_hidden_dim, num_classes)
    
    def forward(self, img, sg):
        img = self.vit(img)
        sg = self.rgcn(sg)
        hidden_state = torch.cat((img, sg), dim=1)
        return F.softmax(self.head(hidden_state), dim=-1)


In [None]:

def train(model, loader, weight, device, epochs=1, lr=1e-2):
    criterion = torch.nn.CrossEntropyLoss(weight=weight)

    model = model.to(device)
    P = model.parameters()

    optimizer = torch.optim.Adam(P, lr=lr)

    for e in range(epochs):
        epoch_loss = 0
        total = 0
        correct = 0
        for batch in tqdm(loader):
            ids, imgs, sgs, verbs, labels = batch

            labels = labels.to(device)
            sgs = sgs.to(device)
            imgs = imgs.to(device)

            optimizer.zero_grad()
            if isinstance(model, RGCN):
                out = model(sgs)
            elif isinstance(model, ViT):
                out = model(imgs)
            elif isinstance(model, JointModel):
                out = model(imgs, sgs)
            else:
                raise ValueError(f'Unknown model type: {model.__class__}')

            loss = criterion(out, labels)

            # Calculate accuracy
            _, predicted = torch.max(out, 1)
            _, labels_1d = torch.max(labels, 1)

            total += labels.size(0)
            correct += (predicted == labels_1d).sum().item()

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f'Epoch {e} loss: {epoch_loss} accuracy: {correct/total}')

In [13]:
%autoreload
epochs = 10
device = torch.device('cuda:0')

node_feature_size = 32
num_obj_classes = len(train_set.object_classes)
num_verb_classes = len(train_set.verb_classes)
num_rel_classes = len(train_set.relationship_classes)
print(num_obj_classes, num_verb_classes, num_rel_classes)

rgcn = RGCN(num_obj_classes, node_feature_size, num_verb_classes, num_rel_classes)
vit = ViT(num_verb_classes)

rgcn_hidden_dim, vit_hidden_dim = 32, 32
rgcn_params = (num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes)
joint_model = JointModel(rgcn_params, vit_hidden_dim, num_verb_classes)

weight = len(train_set) / (num_verb_classes * train_set.verb_label_counts)
weight = torch.tensor(weight, dtype=torch.float).to(device)

train(joint_model, train_loader, weight, device, epochs=epochs)

36 33 26


100%|██████████| 718/718 [02:31<00:00,  4.74it/s]


Epoch 0 loss: 3475.786923289299 accuracy: 0.04938164082912384


100%|██████████| 718/718 [02:34<00:00,  4.64it/s]


Epoch 1 loss: 4240.812429666519 accuracy: 0.06775823027347153


100%|██████████| 718/718 [02:40<00:00,  4.47it/s]


Epoch 2 loss: 4299.314756035805 accuracy: 0.08195436335133252


100%|██████████| 718/718 [02:32<00:00,  4.71it/s]


Epoch 3 loss: 2834.35235619545 accuracy: 0.10102769552342797


100%|██████████| 718/718 [02:45<00:00,  4.33it/s]


Epoch 4 loss: 2970.202409505844 accuracy: 0.11383034314579342


100%|██████████| 718/718 [02:31<00:00,  4.73it/s]


Epoch 5 loss: 2792.954768061638 accuracy: 0.10921442257446438


100%|██████████| 718/718 [02:42<00:00,  4.41it/s]


Epoch 6 loss: 3402.575664460659 accuracy: 0.10912732973349591


100%|██████████| 718/718 [02:37<00:00,  4.57it/s]


Epoch 7 loss: 3129.4967131614685 accuracy: 0.10755965859606341


100%|██████████| 718/718 [02:50<00:00,  4.20it/s]


Epoch 8 loss: 3237.4125183820724 accuracy: 0.11487545723741509


100%|██████████| 718/718 [02:33<00:00,  4.67it/s]

Epoch 9 loss: 3269.2235572338104 accuracy: 0.1145270858735412





# trying pytorch lightning

In [5]:
# trying pytorch lightning

class JointModelLightning(L.LightningModule):
    def __init__(self, model_params, weight):
        super().__init__()
        rgcn_params, vit_hidden_dim, num_classes = model_params 
        self.model = JointModel(rgcn_params, vit_hidden_dim, num_classes)
        self.weight = weight
        
        # Initialize multiple accuracy metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.train_acc_epoch = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        
        self.save_hyperparameters()

        
    def forward(self, img, sg):
        return self.model(img, sg)
    
    def training_step(self, batch, batch_idx):
        # Same batch unpacking as original
        ids, imgs, sgs, verbs, labels = batch
        
        # Forward pass is simpler - no need to move tensors to device 
        # since Lightning handles that automatically
        out = self(imgs, sgs)

        
        # Using simple cross entropy instead of custom WeightedCELoss
        loss = F.cross_entropy(out, labels, weight=self.weight)
        
        # Using Lightning's built-in accuracy metric
        # Accuracy is calculated per batch and automatically aggregated by Lightning
        acc = self.train_accuracy(out, labels)
        self.train_acc_epoch.update(out, labels)
        
        # New: Lightning's built-in logging instead of manual logging
        # Progress bar shows metrics during training
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        # Only return loss - Lightning handles the rest
        return loss
    
    def on_train_epoch_end(self):
        # Log the final metrics for the epoch
        self.log('train_acc_epoch', self.train_acc_epoch.compute())
        
        # Reset metrics at the end of each epoch
        self.train_accuracy.reset()
        self.train_acc_epoch.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)


In [6]:
%autoreload
epochs = 10
device = torch.device('cuda:0')

node_feature_size = 32
num_obj_classes = len(train_set.object_classes)
num_verb_classes = len(train_set.verb_classes)
num_rel_classes = len(train_set.relationship_classes)
print(num_obj_classes, num_verb_classes, num_rel_classes)

rgcn_hidden_dim, vit_hidden_dim = 32, 32
rgcn_params = (num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes)

model_params = (rgcn_params, vit_hidden_dim, num_verb_classes)

weight = len(train_set) / (num_verb_classes * train_set.verb_label_counts)
weight = torch.tensor(weight, dtype=torch.float).to(device)


36 33 26


In [8]:
# Initialize model and trainer
lightning_model = JointModelLightning(model_params, weight)

# Setup callbacks and logger
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    dirpath='checkpoints/',
    filename='joint-model-{epoch:02d}-{train_loss:.2f}',
    save_top_k=3,
    mode='min',
)

logger = TensorBoardLogger("lightning_logs", name="joint_model")

trainer = L.Trainer(
    max_epochs=10,
    accelerator='gpu',
    devices=[0],
    callbacks=[checkpoint_callback],
    logger=logger,
)

# Train the model
trainer.fit(lightning_model, train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | model           | JointModel         | 85.9 M | train
1 | train_accuracy  | MulticlassAccuracy | 0      | train
2 | train_acc_epoch | MulticlassAccuracy | 0      | train
---------------------------------------------------------------
56.7 K    Trainable params
85.8 M    Non-trainable params
85.9 M    Total params
343.421   Total estimated model params size (MB)
164       Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/718 [00:00<?, ?it/s] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined