In [1]:
from tqdm import tqdm
import tensorboard

import numpy as np
import torch

from torch.utils.data import DataLoader
from data.ag.action_genome import AG

import torch.nn.functional as F

from models.rgcn import RGCN
from models.vit import ViT
from models.joint_model import JointModel

import pytorch_lightning as L 
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics

from pyswip import Prolog

import warnings
warnings.filterwarnings("ignore")
#warnings.filterwarnings("default")


In [2]:
root = '/data/Datasets/ag/'
train_set = AG(root, split='train', split_file='data/ag/split_train_val_test.json', subset_file='data/ag/subset_shelve')
val_set = AG(root, split='val', split_file='data/ag/split_train_val_test.json', subset_file='data/ag/subset_shelve')
test_set = AG(root, split='test', split_file='data/ag/split_train_val_test.json', subset_file='data/ag/subset_shelve')

split: train length: 6388
split: val length: 798
split: test length: 799


In [3]:
train_loader = DataLoader(train_set, batch_size=16, collate_fn=train_set.verb_pred_collate, num_workers=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128, collate_fn=val_set.verb_pred_collate, num_workers=16, shuffle=False)
test_loader = DataLoader(test_set, batch_size=128, collate_fn=test_set.verb_pred_collate, num_workers=16, shuffle=False)

# trying pytorch lightning

In [5]:
# trying pytorch lightning

class JointModelLightning(L.LightningModule):
    def __init__(self, model_params, weight, model_type='joint'):
        super().__init__()
        self.model_type = model_type
        rgcn_params, vit_hidden_dim, num_classes = model_params 
        if model_type == 'joint':
            self.model = JointModel(rgcn_params, vit_hidden_dim, num_classes)
        elif model_type == 'rgcn':
            num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes = rgcn_params
            self.model = RGCN(num_obj_classes, node_feature_size, num_classes, num_rel_classes, head=True)
        elif model_type == 'vit':
            self.model = ViT(num_classes, head=True)
        self.weight = weight
        self.constraints = None
        
        #epoch metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        
        self.save_hyperparameters()
        
    def forward(self, img, sg):
        if self.model_type == 'rgcn':
            return self.model(sg)
        elif self.model_type == 'vit':
            return self.model(img)
        else:
            return self.model(img, sg)
    
    def training_step(self, batch, batch_idx):
        ids, imgs, sgs, verbs, labels = batch
        out = self(imgs, sgs)
        
        loss = F.cross_entropy(out, labels, weight=self.weight)
        out, labels = torch.argmax(out, dim=1), torch.argmax(labels, dim=1)
        acc = self.train_accuracy(out, labels)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss

    
    def validation_step(self, batch, batch_idx):
        ids, imgs, sgs, verbs, labels = batch
        out = self(imgs, sgs)
        out, labels = torch.argmax(out, dim=1), torch.argmax(labels, dim=1)
        val_acc = self.val_accuracy(out, labels) 
        
        #self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', val_acc, on_step=False, on_epoch=True, prog_bar=True)
            
        #return val_loss

    def test_step(self, batch, batch_idx):
        ids, imgs, sgs, verbs, labels = batch
        out = self(imgs, sgs)
        
        test_loss = F.cross_entropy(out, labels, weight=self.weight)
        out, labels = torch.argmax(out, dim=1), torch.argmax(labels, dim=1)
        test_acc = self.test_accuracy(out, labels) 
        
        self.log('test_loss', test_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test_acc', test_acc, on_step=False, on_epoch=True, prog_bar=True)
            
        return test_loss
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)


In [6]:
epochs = 10
device = torch.device('cuda:0')

node_feature_size = 32
num_obj_classes = len(train_set.object_classes)
num_verb_classes = len(train_set.verb_classes)
num_rel_classes = len(train_set.relationship_classes)
print(num_obj_classes, num_verb_classes, num_rel_classes)

rgcn_hidden_dim, vit_hidden_dim = 32, 32
rgcn_params = (num_obj_classes, node_feature_size, rgcn_hidden_dim, num_rel_classes)
model_params = (rgcn_params, vit_hidden_dim, num_verb_classes)

print(train_set.verb_label_counts)
weight = len(train_set) / (num_verb_classes * train_set.verb_label_counts)
weight = torch.tensor(weight, dtype=torch.float).to(device)
print(weight)

36 33 26
[ 183  585   74   93  314  394  114   28  655  335   64    8  765   64
  227   85 2052  161  869  636  356  122  799  442   44  442  392   27
  129  272  172  488  103]
tensor([ 1.9033,  0.5954,  4.7068,  3.7452,  1.1092,  0.8840,  3.0553, 12.4394,
         0.5318,  1.0397,  5.4422, 43.5379,  0.4553,  5.4422,  1.5344,  4.0977,
         0.1697,  2.1634,  0.4008,  0.5476,  0.9784,  2.8549,  0.4359,  0.7880,
         7.9160,  0.7880,  0.8885, 12.9001,  2.7000,  1.2805,  2.0250,  0.7137,
         3.3816], device='cuda:0')


In [None]:
model_type = 'vit'

# Initialize model and trainer
lightning_model = JointModelLightning(model_params, weight, model_type=model_type)

# Setup callbacks and logger
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='checkpoints/',
    filename='joint-model-{epoch:02d}-{val_acc:.2f}',
    save_top_k=3,
    mode='max',
)

logger = TensorBoardLogger("lightning_logs", name=f"{model_type}_model")

trainer = L.Trainer(
    max_epochs=10,
    accelerator='gpu',
    devices=[0],
    callbacks=[checkpoint_callback],
    logger=logger,
)

# Train the model
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX 6000 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | model          | ViT                | 85.8 M | train
1 | train_accuracy | MulticlassAccuracy | 0      | train
2 | val_accuracy   | MulticlassAccuracy | 0      | train
3 | test_accuracy  | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
25.4 K    Trainable params
85.8 M    Non-trainable 

Epoch 7:  29%|██▉       | 207/719 [00:16<00:41, 12.23it/s, v_num=0, train_loss_step=5.090, val_acc=0.0459, train_loss_epoch=3.440, train_acc=0.0646]

In [55]:
# test the model
trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at /home/muyang/learned-affordance-constraints/checkpoints/joint-model-epoch=06-val_acc=0.16.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]
Loaded model weights from the checkpoint at /home/muyang/learned-affordance-constraints/checkpoints/joint-model-epoch=06-val_acc=0.16.ckpt


Testing DataLoader 0: 100%|██████████| 11/11 [00:00<00:00, 18.99it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.1647482067346573
     test_loss_epoch         3.276085376739502
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 3.276085376739502, 'test_acc': 0.1647482067346573}]

In [None]:
# test the model with constraints
trainer.test(dataloaders=test_loader)



In [7]:
def test_rules(rules_file, bk_file, test_size, targets, labels=None):
    print('Testing learned rules=====================')
    preds = []

    _ = Prolog()

    Prolog.consult(rules_file)
    Prolog.consult(bk_file)

    for i in range(test_size):
        pred = np.zeros(len(targets))
        for j,v in enumerate(targets):
            q = Prolog.query(f'{v}_target(x{i}_0)')
            for q in q:
                pred[j] = 1
                break
        pred = pred.astype(int)
        preds.append(pred)

    preds = np.stack(preds)
    if labels is not None:
        #metrics(labels, preds)
        pass
    return preds


In [None]:
masks = test_rules('outputs/ag/rules_learned.pl', 'prolog/ag/test_bk.pl', len(test_set), test_set.verb_classes)
lightning_model.test_constraints = masks #list of verb masks, one for each example





   Call: (1) pyrun("consult('outputs/ag/rules_learned.pl')", _6216) ? 