In [22]:
import os
# Switch the working directory to the project checkout before the project-local
# imports below (mgca, models, utils) and the relative paths used later on.
# NOTE(review): presumably required so those packages resolve from the cwd — confirm.
os.chdir('/vol/bitbucket/jq619/individual-project/')

import torch

from pytorch_lightning import Trainer, seed_everything

from mgca.datasets.classification_dataset import RSNAImageDataset, COVIDXImageDataset
from mgca.datasets.transforms import DataTransforms
from models.adaptor import Adaptor, StreamingProgressBar
from models.pipeline import AdaptorPipelineWithClassificationHead
from models.configurations import TEXT_PRETRAINED, VISION_PRETRAINED
from utils.model_utils import load_vision_model
from utils.dataset_utils import torch2huggingface_dataset, get_dataloader

from math import ceil
import argparse 
import logging

In [23]:
# Name of the vision backbone; used as the key into the VISION_PRETRAINED registry.
vision_model = 'resnet-ae'
vision_model_config = VISION_PRETRAINED[vision_model]

# Pull the four settings this notebook needs out of the backbone's config entry.
vision_pretrained, vision_model_type, vision_output_dim, data_transform = (
    vision_model_config[field]
    for field in (
        'pretrained_weight',
        'vision_model_type',
        'vision_output_dim',
        'data_transform',
    )
)

In [24]:
# Image size handed to both the transform factory and the dataset.
imsize = 256

# Build the transform from the backbone's factory.
# NOTE(review): the first positional argument (True) presumably selects the
# training-time variant of the transform — confirm against the factory signature.
train_transform = data_transform(True, imsize)

# RSNA training split for the classification phase, using 100% of the data.
image_dataset = RSNAImageDataset(
    split='train',
    phase='classification',
    data_pct=1.0,
    imsize=imsize,
    transform=train_transform,
)

Setting XRayResizer engine to cv2 could increase performance.


In [25]:
batch_size = 16
# Number of batches of size `batch_size` needed to cover `image_dataset`,
# rounded up so a final partial batch is counted.
max_steps = ceil(len(image_dataset) / batch_size)
train_dataset = torch2huggingface_dataset(image_dataset, streaming=False)
# `with_format` returns a new dataset object instead of modifying the receiver,
# so the result must be rebound; otherwise the torch format is silently dropped.
train_dataset = train_dataset.with_format('torch')
# Bare last expression so the cell still displays the resulting dataset.
train_dataset

<datasets.iterable_dataset.IterableDataset at 0x7fb7256d88e0>

In [26]:
# Loader settings gathered in one place: batch size comes from the cell above,
# eight worker processes, and no custom collate function.
loader_options = dict(
    batch_size=batch_size,
    num_workers=8,
    collate_fn=None,
)
train_dataloader = get_dataloader(train_dataset, **loader_options)

In [27]:
# Fix the global RNG seed before any trainer/model objects are created.
seed_everything(42)

# Custom progress callback, told up front how many steps one epoch contains.
progress_callback = StreamingProgressBar(total=max_steps)

trainer = Trainer(
    # Hardware / schedule: a single epoch on CPU.
    accelerator="cpu",
    max_epochs=1,
    check_val_every_n_epoch=1,
    # Logging and checkpoint location.
    log_every_n_steps=10,
    default_root_dir='./trained_models/clf',
    # The built-in progress bar is turned off; the callback above is registered instead.
    enable_progress_bar=False,
    callbacks=[progress_callback],
)

Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [28]:
# Adaptor pipeline with a two-class classification head on top.
# The vision backbone name is taken from `vision_model` (set in the config cell)
# rather than repeated as a literal, so the model always matches the backbone
# whose `data_transform` was used to build the dataset.
model = AdaptorPipelineWithClassificationHead(
    text_model='biobert', 
    vision_model=vision_model, 
    # Adaptor checkpoint to load.
    # NOTE(review): absolute, user-specific path — consider making it configurable.
    adaptor_ckpt='/vol/bitbucket/jq619/individual-project/results/resnet-ae_biobert/lightning_logs/version_76005/checkpoints/epoch=29-step=53279.ckpt', 
    num_classes=2, 
)

In [None]:
# Run training on the training dataloader only; no validation dataloader is passed.
trainer.fit(model, train_dataloader)


  | Name         | Type             | Params
--------------------------------------------------
0 | vision_model | _ResNetAE        | 31.1 M
1 | text_model   | BertModel        | 108 M 
2 | adaptor      | Adaptor          | 8.7 M 
3 | classifier   | Linear           | 1.5 K 
4 | loss_func    | CrossEntropyLoss | 0     
--------------------------------------------------
8.7 M     Trainable params
139 M     Non-trainable params
148 M     Total params
592.375   Total estimated model params size (MB)


Epoch 0:   0%|          | 0/1168 [00:00<?, ?it/s] 