#### Import Packages

In [None]:
import os
import json
import torch
import numpy as np
import random
import torch.multiprocessing as mp
import matplotlib.pyplot as plt

from torch.backends import cudnn
from class_train.dlip_retrieval_train import DLIPRetrievalTrain
from utils.system import get_config

# Set CUDA_LAUNCH_BLOCKING=1 in this process's environment.
# NOTE(review): this is normally a debugging setting (synchronous kernel
# launches) and is expected to slow GPU execution -- confirm it is wanted
# for full training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [None]:
# Report the installed PyTorch version.
print(torch.__version__)

#### Get GPU

##### Run this command to install PyTorch with CUDA support (requires CUDA 11.8 or above; I have CUDA 12.1 installed):
`pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`

In [None]:
# Get Device and Number of Devices
# Use the default CUDA device when CUDA is available, otherwise fall back to CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

# Always bind num_devices (0 on CPU-only hosts) so that later cells can rely
# on it regardless of which branch below runs.
num_devices = torch.cuda.device_count() if cuda_available else 0

if cuda_available:
    # List every visible CUDA device by index and name.
    for i in range(num_devices):
        print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")
else:
    print(f"Using: {device}")

In [None]:
# Seed every random number generator used by this notebook with the same
# value: PyTorch first, then NumPy, then Python's `random` module.
seed = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# NOTE(review): cudnn.benchmark = True lets cuDNN auto-tune its algorithms,
# which may pick different kernels between runs and so can work against the
# reproducibility goal above -- confirm the speed trade-off is intended.
cudnn.benchmark = True

#### Train

In [None]:
# Params
# Load the training configuration. The file is opened in a `with` block so
# the handle is closed as soon as the JSON has been parsed.
with open(get_config() / 'train_retrieval_config.json') as config_file:
    param = json.load(config_file)

# Configuration keys forwarded to DLIPRetrievalTrain; each key is passed as
# the keyword argument of the same name.
TRAIN_ARG_KEYS = (
    'azure',
    'connection_string',
    'container_name',
    'multi',
    'ddp_server',
    'queue_size',
    'image_size',
    'batch_size',
    'world_size',
    'partial',
    'val_split',
    'device',
    'scheduler',
    'warmup_steps',
    'warmup_lr',
    'min_lr',
    'lr_decay_rate',
    'dlip_bert_pretrain',
    'dlip_bert',
    'dlip_vit',
    'dlip_blip',
    'alpha',
    'accumulate',
    'learning_rate',
    'weight_decay',
    'dlip_epoch',
    'checkpoint_epoch',
    'num_epoch',
    'freeze',
    'gradient_clip',
    'print',
)

# DLIPRetrievalTrain
# A missing key raises KeyError here, exactly as direct `param[...]` indexing
# would. Note that `device` comes from the config file, not from the `device`
# variable detected earlier in this notebook.
dlip_retrieval_train = DLIPRetrievalTrain(
    **{key: param[key] for key in TRAIN_ARG_KEYS}
)

In [None]:
# Train
# The config stores the multi-process flag as the string "True".
if param['multi'] == "True":
    # Launch one worker process per rank; mp.spawn passes the process index
    # as the first positional argument to `train`.
    mp.spawn(
        dlip_retrieval_train.train,
        args=(),
        nprocs=dlip_retrieval_train.world_size
    )
    # mp.spawn does not hand the workers' return values back to this process,
    # so no loss history is available here. Bind empty histories so the
    # plotting cells below do not fail with NameError.
    losses_itm, losses_ita, losses_dist = [], [], []
    print("Multi-process training finished; loss histories are not returned "
          "by spawned workers, so the loss plots will be empty.")
else:
    # Single-process run: train as rank 0 and keep the three loss histories.
    losses_itm, losses_ita, losses_dist = dlip_retrieval_train.train(0)

#### Plot Loss

In [None]:
# Plot Loss
def plot_loss(loss, title):
    """Draw one loss curve on its own wide figure and display it.

    Args:
        loss: Sequence of loss values to plot, in order.
        title: Name of the loss; used as the legend label and as the
            prefix of the figure title.
    """
    # NOTE(review): the x-axis is labelled 'Epochs' while the title says
    # 'per Iteration' -- confirm which unit the loss list is recorded in.
    _, ax = plt.subplots(figsize=(20, 6))
    ax.plot(loss, label=title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title(f'{title} per Iteration')
    ax.legend()
    plt.show()

In [None]:
# Plot loss
# Render one figure per loss history, in the order ITM, ITA, Dist.
for loss_values, loss_name in (
    (losses_itm, 'Loss ITM'),
    (losses_ita, 'Loss ITA'),
    (losses_dist, 'Loss Dist'),
):
    plot_loss(loss_values, loss_name)