In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader

import os
from tqdm import tqdm

from dataset.dataset import (
    MultimodalPretrainedEmbeddingsDatasetLoader, 
    MultimodalPretrainedEmbeddingsDataset, 
)

In [9]:
# Root folder holding the pre-computed embeddings. The original cluster path is
# kept as the default; set the SAVED_EMBEDDINGS_DIR environment variable to run
# on another machine without editing the notebook.
saved_embeddings_dir = (
    os.environ.get('SAVED_EMBEDDINGS_DIR')
    or '/vol/bitbucket/jq619/individual-project/saved_embeddings'
)

# One sub-folder per modality (the folder names suggest ClinicalBERT text
# embeddings and Swin-Base image embeddings -- TODO confirm against the export script).
text_embeds_raw_dir = os.path.join(saved_embeddings_dir, 'text_embeds', 'ClinicalBERT')
image_embeds_raw_dir = os.path.join(saved_embeddings_dir, 'image_embeds', 'Swin-Base')

# Loader for the 'train' split, limited to 50 batches via num_of_batches.
dataset_loader = MultimodalPretrainedEmbeddingsDatasetLoader(
    text_embeds_raw_dir,
    image_embeds_raw_dir,
    split='train',
    num_of_batches=50,
)

In [10]:
# Read the configured split through the loader. The result is later passed to
# Dataset.from_dict, so it is presumably a dict of columns -- TODO confirm.
dataset = dataset_loader.load_data()

100%|██████████| 50/50 [00:03<00:00, 12.56it/s]


In [13]:
# Import the Hugging Face class under an alias: a bare `Dataset` import here
# would silently rebind the `torch.utils.data.Dataset` name imported in the
# first cell of the notebook.
from datasets import Dataset as HFDataset

# Wrap the loaded embeddings in a Hugging Face dataset; it is handed to the
# Trainer as `train_dataset` further down.
train_dataset = HFDataset.from_dict(dataset)

In [25]:
from argparse import Namespace
from models.adaptor import Adaptor
from models.configurations import (
    TEXT_PRETRAINED_AVAILABLE,
    VISION_PRETRAINED_AVAILABLE,
    VISION_MODEL_TYPE_2_DATA_TRANSFORM,
    VISION_MODEL_TYPE_2_VISION_OUTPUT_DIM, 
)
from utils.utils import load_timm_model, freeze_encoder
from utils.model_utils import load_vision_model
from transformers import AutoTokenizer
from transformers import BertModel, AutoModel, ViTImageProcessor

from transformers import TrainingArguments, Trainer

# Run configuration gathered in a single namespace so that later cells can read
# each setting as an attribute (args.lr, args.seed, ...).
args = Namespace(
    batch_size=16,
    vision_pretrained='swin_base_patch4_window7_224',
    vision_model_type='timm',
    text_pretrained='./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint',
    num_train_epochs=1,
    lr=1e-4,
    projection_dim=768,
    num_hidden_layers=1,
    seed=1117,
)

In [34]:
class CustomTrainer(Trainer):
    """Trainer variant that builds its training DataLoader straight from `self.train_dataset`.

    NOTE(review): presumably overridden so the dataset reaches the model without
    the stock Trainer's dataset preprocessing -- confirm against the installed
    transformers version.
    """

    def get_train_dataloader(self) -> DataLoader:
        """Create the DataLoader used for training.

        The dataset is passed to `DataLoader` exactly as stored on the trainer,
        together with the sampler returned by `_get_train_sampler()` and the
        trainer's configured collator.

        Returns:
            A `torch.utils.data.DataLoader` over `self.train_dataset`.

        Raises:
            ValueError: if the trainer was created without a `train_dataset`.
        """
        from transformers.trainer_utils import seed_worker

        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        sampler = self._get_train_sampler()

        # Loader options taken from the trainer and its TrainingArguments.
        loader_options = dict(
            batch_size=self._train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )
        return DataLoader(self.train_dataset, **loader_options)

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch before any model is built. `seed` is otherwise only handed to
# TrainingArguments below, i.e. after the Adaptor already exists, so any random
# parameter initialisation inside it (presumably the projection layers) would
# differ from run to run.
torch.manual_seed(args.seed)

### Load vision model
# A checkpoint registered in VISION_PRETRAINED_AVAILABLE must map to the
# selected vision_model_type; unregistered checkpoints skip this check.
if args.vision_pretrained in VISION_PRETRAINED_AVAILABLE:
    assert VISION_PRETRAINED_AVAILABLE[args.vision_pretrained] == args.vision_model_type, \
        'Vision model type does not match pretrained model'
vision_model = load_vision_model(args.vision_model_type, args.vision_pretrained)

### Load text model
text_model = BertModel.from_pretrained(args.text_pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.text_pretrained)

### Define model
# A CLS token is only requested for the 'ae' vision model type.
add_cls_token = args.vision_model_type == 'ae'
vision_output_dim = VISION_MODEL_TYPE_2_VISION_OUTPUT_DIM[args.vision_model_type]
model = Adaptor(
    text_model=text_model,
    vision_model=vision_model,
    vision_model_type=args.vision_model_type,
    vision_output_dim=vision_output_dim,
    projection_dim=args.projection_dim,
    num_hidden_layers=args.num_hidden_layers,
    add_cls_token=add_cls_token,
)
freeze_encoder(model)  # freeze encoder
# NOTE(review): the Hugging Face Trainer may itself wrap the model in
# DataParallel when several GPUs are visible -- confirm this manual wrap is
# still wanted in that setting.
model = nn.DataParallel(model)
model.to(device)


### Training
arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    num_train_epochs=args.num_train_epochs,
    logging_steps=20,
    save_strategy="epoch",
    learning_rate=args.lr,
    seed=args.seed,
    push_to_hub=False,
)

# No eval_dataset is supplied, so this run only trains. data_collator=None
# leaves the choice of collator to the Trainer.
trainer = CustomTrainer(
    model=model,
    args=arguments,
    train_dataset=train_dataset,
    data_collator=None,
)
trainer.train()

Some weights of the model checkpoint at ./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Step,Training Loss
20,2.7804
40,2.7678
60,2.3801
80,1.651
100,1.4545
120,1.2626
140,1.1937
160,1.1436
180,1.0903
200,1.101


TrainOutput(global_step=400, training_loss=1.3243766689300538, metrics={'train_runtime': 589.8209, 'train_samples_per_second': 10.851, 'train_steps_per_second': 0.678, 'total_flos': 0.0, 'train_loss': 1.3243766689300538, 'epoch': 1.0})