## Configs

In [None]:
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchvision import transforms

from multidata import RALO_Datasets
from multinet import MultiModel

# --- Experiment configuration ---
architecture = 'vit_tiny_patch16_224'  # timm model name used for the backbones
loss_func = nn.L1Loss()
batch_size = 32
num_epochs = 60
num_workers = 4
seed = 42

# Training shuffles the data and initialises new head weights, so seed every
# RNG (Python, NumPy, torch) up front to make Restart & Run All reproducible.
pl.seed_everything(seed)

## Data

In [None]:
# Resize to the ViT input size; ToTensor scales 8-bit images to [0, 1] and the
# 0.5/0.5 normalisation then maps each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Both splits read the same image folder and CSV and differ only in `subset`.
# Every argument is passed by keyword (a positional argument after keyword
# arguments is a SyntaxError).
dataset_kwargs = dict(
    imgpath="covid-severity/Lung_Rep/",
    csvpath="covid-severity/lung_rep.csv",
    transform=transform,
)
train = RALO_Datasets(subset="train", **dataset_kwargs)
test = RALO_Datasets(subset="test", **dataset_kwargs)

print(train)
print(test)

# Only the training split is shuffled; evaluation order stays fixed.
test_loader = DataLoader(dataset=test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(dataset=train, batch_size=batch_size, shuffle=True, num_workers=num_workers)

RALO_Dataset(train): 5436 images
RALO_Dataset(test): 494 images


## Model

In [None]:
from timm.models import create_model


class Model(pl.LightningModule):
    """Single-backbone ViT regressor built on a timm model.

    Subclasses ``pl.LightningModule`` so that ``load_from_checkpoint`` (used to
    restore the backbones below) exists, and so that submodules and
    ``forward`` behave like a regular ``nn.Module``.

    Args:
        architecture: timm model name passed to ``create_model``.
        loss_func: loss module stored on the instance. Defaults to a fresh
            ``nn.L1Loss()`` per instance.
    """

    def __init__(self, architecture='vit_tiny_patch16_224', loss_func=None):
        super().__init__()

        # Create the default inside __init__ so instances do not share a single
        # module object built once at class-definition time.
        if loss_func is None:
            loss_func = nn.L1Loss()

        # Use the name imported above; the bare `timm` module is never imported.
        self.model = create_model(architecture, pretrained=True, num_classes=1)
        self.loss_func = loss_func

        # Drop the final norm and the pre-logits stage. `pre_logits` is timm's
        # attribute name.
        # NOTE(review): presumably this mirrors how the checkpoints were trained; confirm.
        self.model.norm = nn.Identity()
        self.model.pre_logits = nn.Identity()

        # Two-output MLP head. Its input width follows the backbone's feature
        # size instead of a hard-coded 192 (vit_tiny's width), so other
        # architectures also work.
        num_features = self.model.num_features
        self.model.head = nn.Sequential(nn.Linear(num_features, 128), nn.Linear(128, 2))

    def forward(self, images):
        """Run the backbone, including the replaced head, on a batch of images."""
        return self.model(images)

In [None]:
# Untrained instance built from the config values. The checkpointed weights
# below are loaded independently of it.
model = Model(architecture, loss_func)

# Checkpoints for the two backbones ("g" and "o": presumably the geographic-extent
# and lung-opacity scores; confirm).
# NOTE(review): absolute, user-specific paths; move them to a configurable directory.
g_chk_path = '/home/bslika/Downloads/ge.ckpt'
o_chk_path = '/home/bslika/Downloads/lo.ckpt'

# `load_from_checkpoint` is a classmethod, so call it on the class (recent
# Lightning versions reject calls on an instance). The keyword arguments
# override any constructor arguments stored in the checkpoint, so both
# backbones use the architecture and loss from the config cell.
modelg = Model.load_from_checkpoint(g_chk_path, architecture=architecture, loss_func=loss_func)
modelo = Model.load_from_checkpoint(o_chk_path, architecture=architecture, loss_func=loss_func)

In [None]:
# Wrap both loaded backbones and the shared loss into one MultiModel; this is
# the module that is trained and evaluated in the cells below.
multi_model = MultiModel(modelg, modelo, loss_func)

## Training/Testing

In [None]:
# Lightning trainer for the combined model. `benchmark=True` lets cuDNN choose
# the fastest kernels, which suits the fixed 224x224 inputs from the data cell.
trainer = pl.Trainer(
    max_epochs=num_epochs,
    gpus=1,
    benchmark=True,
    progress_bar_refresh_rate=50,
)

# NOTE(review): test_loader is passed as the validation loader, so the per-epoch
# "validation" metrics are measured on the held-out test split.
trainer.fit(multi_model, train_loader, test_loader)

## Test Results and Loss Curves

In [None]:
# Evaluate the in-memory multi_model (its current weights) on the test split.
# The returned list of metric dicts is shown as the cell output.
trainer.test(model=multi_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Loss/Test': 0.7263020277023315,
 'MAE_SDVg/Test': 0.09035155177116394,
 'MAE_SDVo/Test': 0.06867062300443649,
 'MAEg/Test': 0.5804502367973328,
 'MAEo/Test': 0.486172616481781,
 'PCg/Test': 0.9303927421569824,
 'PCo/Test': 0.8863729238510132}
--------------------------------------------------------------------------------


[{'Loss/Test': 0.7263020277023315,
  'PCg/Test': 0.9303927421569824,
  'MAEg/Test': 0.5804502367973328,
  'PCo/Test': 0.8863729238510132,
  'MAEo/Test': 0.486172616481781,
  'MAE_SDVg/Test': 0.09035155177116394,
  'MAE_SDVo/Test': 0.06867062300443649}]

In [None]:
import matplotlib.pyplot as plt


def _as_floats(values):
    """Convert a sequence of scalar tensors into plain Python floats."""
    return [v.detach().cpu().item() for v in values]


# Read the histories from the trained module. The bare `Model` instance has no
# training loop and never records these lists.
# NOTE(review): assumes MultiModel appends one scalar tensor per epoch to the
# tr_*/vl_*/ts_* lists; confirm against multinet.MultiModel.
train_loss = _as_floats(multi_model.tr_loss)
valid_loss = _as_floats(multi_model.vl_loss)
test_loss = _as_floats(multi_model.ts_loss)

train_mae = _as_floats(multi_model.tr_mae)
valid_mae = _as_floats(multi_model.vl_mae)
test_mae = _as_floats(multi_model.ts_mae)

# Derive the x-axes from the recorded lengths. Validation has presumably one
# extra point from Lightning's sanity-check pass, and hard-coding num_epochs or
# num_epochs + 1 fails as soon as the lengths differ (early stop, no sanity check).
epochs = range(len(train_loss))
vepochs = range(len(valid_loss))

fig, ax = plt.subplots()
ax.plot(epochs, train_loss, label='Train')
ax.plot(vepochs, valid_loss, label='Test')
ax.set(xlabel='Epochs', ylabel='Loss', title='Loss per epoch')
ax.legend(loc='upper right');

## Load the trained multi-task model from a checkpoint

In [None]:
# Restore the combined model from disk into its own variable, so the trained
# `multi_model` in memory is not replaced by a fresh, untrained instance.
# `load_from_checkpoint` is a classmethod, so no throwaway instance is needed.
# NOTE(review): constructor arguments come from the hyperparameters stored in
# the checkpoint; if MultiModel does not save them, pass them here as keyword args.
chk_path = 'multi_weights/checkpoints/epoch=59-step=10200.ckpt'
loaded_model = MultiModel.load_from_checkpoint(chk_path)