# Experiment A

In [2]:
!pip install wfdb



In [3]:
import random
import os
import itertools
from glob import glob
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import scipy
import wfdb
import cv2
import zipfile
import ast

from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score, precision_score, recall_score, f1_score

from albumentations.pytorch import ToTensorV2
from transformers import AutoTokenizer, AutoModel
import albumentations as A

from sklearn.model_selection import train_test_split

In [4]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [5]:
!cp -r  "/content/drive/MyDrive/ECG Project (Shared Folder)/2024 - PTB-XL + Reports (Scientific Paper)/PTB-XL.zip" "/content"

In [6]:
!cp -r  "/content/drive/MyDrive/ECG Project (Shared Folder)/2024 - PTB-XL + Reports (Scientific Paper)/ptbxl_database_translatedENG.csv" "/content"

In [7]:
with zipfile.ZipFile('PTB-XL.zip', 'r') as zip_ref:
    zip_ref.extractall('/content')

In [8]:
sys.path.append('/content/drive/MyDrive/ECG Project (Shared Folder)/2024 - PTB-XL + Reports (Scientific Paper)/Experiment A')

In [9]:
import codes

In [10]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Use a fixed seed value
set_seed(42)

In [11]:
def load_raw_data(df, sampling_rate, path):
    data = []
    not_found_files = []
    if sampling_rate == 100:
        filenames = df.filename_lr
    else:
        filenames = df.filename_hr

    for f in filenames:
        try:
            record = wfdb.rdsamp(path+f)
            data.append(record)
        except FileNotFoundError:
            not_found_files.append(f)
            continue

    data = np.array([signal for signal, meta in data])
    return data, not_found_files

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

In [12]:
path = '/content/PTB-XL/'
sampling_rate = 100

In [13]:
import torch

class CONFIG:
    debug = False
    batch_size = 256
    num_workers = 8
    head_lr = 0.001
    image_encoder_lr = 0.001
    patience = 5
    factor = 0.8
    epochs = 20
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image Model
    model_name = 'resnet18'
    image_embedding_size = 512

    # Text Model
    text_encoder_model = 'emilyalsentzer/Bio_ClinicalBERT'
    text_tokenizer = 'emilyalsentzer/Bio_ClinicalBERT'
    text_embedding_size = 768
    max_length = 200

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    temperature = 10.0
    optimizer = torch.optim.Adam

    # image size
    size = 224

    # for projection head; used for both image and text encoder
    num_projection_layers = 1
    projection_dim = 512
    dropout = 0.0
    ecg_sr = 128

In [14]:
# load and convert annotation data
Y = pd.read_csv('/content/drive/MyDrive/ECG Project (Shared Folder)/2024 - PTB-XL + Reports (Scientific Paper)/ptbxl_database_translatedENG.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [15]:
Y.head(2)

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709,56,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 9:17:34,sinusrhythmus periphere niederspannung,...,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr
2,13243,19,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr


In [16]:
Y['report_ENG'][0:5]

ecg_id
1       sinus rhythm peripheral low voltage
2    sinus bradycardia otherwise normal ekg
3                   sinus rhythm normal ekg
4                   sinus rhythm normal ekg
5                   sinus rhythm normal ekg
Name: report_ENG, dtype: object

In [17]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

In [None]:
agg_df

In [18]:
# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [19]:
Y.head(5)

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709,56,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 9:17:34,sinusrhythmus periphere niederspannung,...,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM]
2,13243,19,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM]
3,20372,37,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM]
4,17014,24,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM]
5,17448,19,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM]


In [20]:
# Get the value counts
Y_value_counts = Y['diagnostic_superclass'].value_counts()

# Get the values where the count > 100
records_to_keep = Y_value_counts[Y_value_counts > 100].index

# Filter the DataFrame
Y_filtered = Y[Y['diagnostic_superclass'].isin(records_to_keep)]

In [21]:
print("There are {0} records before filtering.".format(len(Y)))
print("There are {0} records after filtering.".format(len(Y_filtered)))

There are 21799 records before filtering.
There are 21430 records after filtering.


```
Y_filtered['diagnostic_superclass'].value_counts()

[NORM]                 9069
[MI]                   2532
[STTC]                 2400
[CD]                   1708
[MI, CD]               1297
[STTC, HYP]             781
[STTC, MI]              599
[HYP]                   535
[STTC, CD]              471
[]                      411
[NORM, CD]              407
[STTC, HYP, MI]         361
[HYP, CD]               300
[STTC, MI, CD]          223
[STTC, HYP, CD]         211
[HYP, MI]               183
[STTC, HYP, MI, CD]     156
[HYP, MI, CD]           117
Name: diagnostic_superclass, dtype: int64
```



In [22]:
# Filter the DataFrame for test data
Y_filter_test = Y_filtered[Y_filtered['diagnostic_superclass'].apply(lambda x: 'HYP' in x)]

# Filter the DataFrame for train data
Y_filter_train = Y_filtered[~Y_filtered['diagnostic_superclass'].apply(lambda x: 'HYP' in x)]

In [23]:
Y_filter_train['diagnostic_superclass'].value_counts()

[NORM]            9069
[MI]              2532
[STTC]            2400
[CD]              1708
[CD, MI]          1297
[STTC, MI]         599
[STTC, CD]         471
[]                 411
[NORM, CD]         407
[STTC, CD, MI]     223
Name: diagnostic_superclass, dtype: int64

In [24]:
Y_filter_test['diagnostic_superclass'].value_counts()

[STTC, HYP]            781
[HYP]                  535
[STTC, HYP, MI]        361
[CD, HYP]              238
[HYP, MI]              183
[STTC, HYP, CD]        113
[STTC, HYP, CD, MI]    102
Name: diagnostic_superclass, dtype: int64

In [25]:
# Load raw signal data - Training Data
ECG_train = load_raw_data(Y_filter_train, sampling_rate, path)

# Load raw signal data - Testing Data
ECG_test = load_raw_data(Y_filter_test, sampling_rate, path)

In [26]:
ECG_train[0].shape, ECG_test[0].shape

((19117, 1000, 12), (2313, 1000, 12))

In [27]:
print(len(Y_filter_train))
print(len(ECG_train))

19117
2


In [28]:
# Split the training data into training and validation subsets
Y_train, Y_val, ECG_train, ECG_val = train_test_split(Y_filter_train['report_ENG'], ECG_train[0], test_size=0.2, random_state=42)

In [29]:
len(Y_train), len(ECG_train), len(Y_val), len(ECG_val)

(15293, 15293, 3824, 3824)

In [30]:
def load_wsdb(file):
    file = os.path.splitext(file)[0]
    record = wfdb.io.rdrecord(file)
    ecg = record.p_signal.T.astype('float32')
    leads = tuple(record.sig_name)
    sr = record.fs
    ecg[np.isnan(ecg)] = 0.0
    return ecg, leads, sr

In [31]:
import torch
from torch import nn

_ACTIVATION_DICT = {'relu': nn.ReLU,
                    'tanh': nn.Tanh,
                    'none': nn.Identity,
                    'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.2)}

class Conv1dBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 act='relu', bn=True, dropout=None,
                 maxpool=None, padding=None, stride=1):

        super().__init__()

        if padding is None or padding == 'same':
            padding = kernel_size // 2

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, bias=not bn, stride=stride)
        self.bn = nn.BatchNorm1d(out_channels) if bn else None
        self.act = _ACTIVATION_DICT[act]()
        self.dropout = None if dropout is None else nn.Dropout(dropout)
        self.maxpool = None if maxpool is None else nn.MaxPool1d(maxpool)

    def forward(self, x):
        x = self.conv(x)

        if self.bn is not None:
            x = self.bn(x)

        x = self.act(x)

        if self.dropout is not None:
            x = self.dropout(x)

        if self.maxpool is not None:
            x = self.maxpool(x)

        return x

class LinearBlock(nn.Module):
    def __init__(self, in_channels, out_channels, act='relu', bn=True, dropout=None):

        super().__init__()

        self.linear = nn.Linear(in_channels, out_channels, bias=not bn)
        self.bn = nn.BatchNorm1d(out_channels) if bn else None
        self.act = _ACTIVATION_DICT[act]()
        self.dropout = None if dropout is None else nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear(x)

        if self.bn is not None:
            x = self.bn(x)

        x = self.act(x)

        if self.dropout is not None:
            x = self.dropout(x)
        return x

class ConvEncoder(nn.Module):
    def __init__(self, in_channels, channels, kernels, bn=True, dropout=None, maxpool=2, padding=0, stride=None):
        super().__init__()



        num_layers = len(channels)
        if stride is None:
            stride = [1] * num_layers

        self.in_layer = Conv1dBlock(in_channels, channels[0], kernels[0], bn=bn, dropout=dropout, maxpool=maxpool, padding=padding, stride=stride[0])

        conv_layers = list()
        for i in range(1, num_layers):
            conv_layers.append(Conv1dBlock(channels[i - 1], channels[i], kernels[i], bn=bn, dropout=dropout, maxpool=maxpool, padding=padding, stride=stride[i]))
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, x):
        x = self.in_layer(x)
        for layer in self.conv_layers:
            x = layer(x)
        return x

class ECGEncoder(nn.Module):
    def __init__(self,
                 window=1000,
                 in_channels=12,
                 channels=(32, 32, 64, 64, 128, 128, 256, 256),
                 kernels=(7, 7, 5, 5, 3, 3, 3, 3),
                 linear=512,
                 output=512):

        super().__init__()

        self.conv_encoder = ConvEncoder(in_channels, channels,  kernels, bn=True)

        with torch.no_grad():
            inpt = torch.zeros((1, in_channels, window), dtype=torch.float32)
            outpt = self.conv_encoder(inpt)
            output_window = outpt.shape[2]

        self.flatten = nn.Flatten()
        self.conv_to_linear = nn.Linear(output_window * channels[-1], linear)
        self.act = nn.ReLU()
        self.out_layer = nn.Linear(linear, output)

    def forward(self, x):
        print(x.shape)
        x = self.conv_encoder(x)
        print(x.shape)
        x = self.flatten(x)
        print(x.shape)
        x = self.conv_to_linear(x)
        x = self.act(x)
        x = self.out_layer(x)
        return x

In [32]:
class ImageEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = CONFIG
        self.encoder = ECGEncoder(output=CONFIG.image_embedding_size)

    def forward(self, x):
        x = self.encoder(x)
        return x


class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        if self.config.pretrained:
            self.model = AutoModel.from_pretrained(self.config.text_encoder_model)
        else:
            self.model = AutoModel.from_config(self.config.text_encoder_model)

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_tokenizer)

        for p in self.model.parameters():
            p.requires_grad = False  # Set requires_grad to False for all parameters

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, texts):
        input_ids, attention_mask = self.tokenize_texts(texts)
        embeddinbgs = self.inputs_to_embeddings(input_ids, attention_mask)
        return embeddinbgs

    def tokenize_texts(self, texts):
        inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt')
        input_ids = inputs['input_ids'].detach().to(self.config.device)
        attention_mask = inputs['attention_mask'].detach().to(self.config.device)
        return input_ids, attention_mask

    def inputs_to_embeddings(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :].detach()


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CONFIG.projection_dim,
        dropout=CONFIG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CONFIG.temperature,
        image_embedding=CONFIG.image_embedding_size,
        text_embedding=CONFIG.text_embedding_size,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder(CONFIG)
        self.text_encoder = TextEncoder(CONFIG)
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        image_embeddings = self.image_to_embeddings(batch['image'])
        text_embeddings = self.text_to_embeddings(batch['caption'])

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean(), image_embeddings, text_embeddings

    def text_to_embeddings(self, texts):
        text_features = self.text_encoder(texts)
        text_embeddings = self.text_projection(text_features)
        return text_embeddings

    def image_to_embeddings(self, images):
        image_features = self.image_encoder(images)
        image_embeddings = self.image_projection(image_features)
        return image_embeddings


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()


def nxn_cos_sim(A, B, dim=1):
    a_norm = F.normalize(A, p=2, dim=dim)
    b_norm = F.normalize(B, p=2, dim=dim)
    return torch.mm(a_norm, b_norm.transpose(0, 1))

In [33]:
def calc_metrics(image_embeddings, captions, class_embeddings, class_names):
    similarity = nxn_cos_sim(image_embeddings, class_embeddings, dim=1)
    predictions_ids = similarity.argmax(dim=1)
    predictions = [class_names[idx] for idx in predictions_ids]
    tps = [prediction in caption for prediction, caption in zip(predictions, captions)]
    accuracy = np.mean(tps)

    results = dict()
    results['accuracy'] = accuracy

    similarity = similarity.detach().cpu().numpy()
    for i, name in enumerate(class_names):

        true = np.array([name in caption for caption in captions]).astype('int32')

        if true.std() > 0:
            results[f'{name}_rocauc'] = roc_auc_score(true, similarity[:, i])
            results[f'{name}_prauc'] = average_precision_score(true, similarity[:, i])
        else:
            results[f'{name}_rocauc'] = None
            results[f'{name}_prauc'] = None

    return results

def calc_accuracy(image_embeddings, captions, class_embeddings, class_names):
    similarity = nxn_cos_sim(image_embeddings, class_embeddings, dim=1)
    predictions_ids = similarity.argmax(dim=1)
    predictions = [class_names[idx] for idx in predictions_ids]
    tps = [prediction in caption for prediction, caption in zip(predictions, captions)]
    accuracy = np.mean(tps)
    return accuracy

In [34]:
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

# Dataloaders

In [35]:
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, ECG_data, labels):
        self.ECG_data = ECG_data
        self.labels = labels.tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ECG = self.ECG_data[idx]
        label = self.labels[idx]

        return {
            'image': torch.tensor(ECG, dtype=torch.float32),
            'caption': label
        }

In [36]:
train_dataset = ECGDataset(ECG_train, Y_train)
valid_dataset = ECGDataset(ECG_val, Y_val)

In [37]:
valid_dataset[0]

{'image': tensor([[ 0.4140, -0.1390, -0.5520,  ...,  0.5310,  0.1890,  0.2060],
         [ 0.1510, -0.1680, -0.3200,  ...,  0.1370, -0.1900, -0.0660],
         [-0.2320, -0.1550,  0.0770,  ...,  0.2920, -0.2570, -0.1710],
         ...,
         [-0.0150, -0.0630, -0.0480,  ..., -0.4830, -0.0170, -0.0360],
         [-0.0260, -0.0710, -0.0440,  ..., -0.5030, -0.0280, -0.0470],
         [-0.0310, -0.0780, -0.0470,  ..., -0.5100, -0.0360, -0.0610]]),
 'caption': 'sinus rhythm. right bundle branch block.'}

In [38]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False, pin_memory=True)

## Training and Validation

In [43]:
def train_epoch(model, loader, optimizer, classes):
    tqdm_object = tqdm(loader, total=len(loader))
    loss_meter = AvgMeter()
    accuracy_meter = AvgMeter()
    for batch in tqdm_object:
        model.train()
        batch = {k: v.to(CONFIG.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        loss, image_embeddings, text_embeddings = model(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        model.eval()
        with torch.no_grad():
            class_embeddings = model.text_to_embeddings(classes)

        accuracy = calc_accuracy(image_embeddings, batch['caption'], class_embeddings, classes)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)
        accuracy_meter.update(accuracy, count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, train_accuracy=accuracy_meter.avg)

    return loss_meter, accuracy_meter


def valid_epoch(model, loader, classes):
    model.eval()

    with torch.no_grad():
        class_embeddings = model.text_to_embeddings(classes).detach().cpu()

    tqdm_object = tqdm(loader, total=len(loader))
    embeddings = list()
    captions = list()
    with torch.no_grad():
        for batch in tqdm_object:
            batch = {k: v.to(CONFIG.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            loss, image_embeddings, text_embeddings = model(batch)
            embeddings.append(image_embeddings.cpu())
            captions += batch['caption']

    embeddings = torch.cat(embeddings)

#     plt.figure(figsize=(30, 5))
#     plt.hist(class_embeddings.numpy().flatten(), bins=100)
#     plt.grid()
#     plt.show()


#     plt.figure(figsize=(30, 5))
#     plt.hist(embeddings.numpy().flatten(), bins=100)
#     plt.grid()
#     plt.show()


    metric = calc_metrics(embeddings, captions, class_embeddings, classes)
    return metric

In [39]:
model = CLIPModel().to(CONFIG.device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [40]:
params = [
    {"params": model.image_encoder.parameters(), "lr": CONFIG.image_encoder_lr},
    {"params": model.image_projection.parameters(), "lr": CONFIG.head_lr},
    {"params": model.text_projection.parameters(), "lr": CONFIG.head_lr},
]

optimizer = CONFIG.optimizer(params)

In [41]:
def train_epoch(model, loader, optimizer, classes):
    tqdm_object = tqdm(loader, total=len(loader))
    loss_meter = AvgMeter()
    accuracy_meter = AvgMeter()
    for batch in tqdm_object:
        model.train()
        batch = {k: v.to(CONFIG.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        batch['image'] = batch['image'].transpose(1, 2)  # transpose the last two dimensions
        print(batch['image'].shape)
        loss, image_embeddings, text_embeddings = model(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            class_embeddings = model.text_to_embeddings(classes)

        accuracy = calc_accuracy(image_embeddings, batch['caption'], class_embeddings, classes)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)
        accuracy_meter.update(accuracy, count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, train_accuracy=accuracy_meter.avg)

    return loss_meter, accuracy_meter

In [42]:
# Assuming 'classes' is a list of all unique labels in dataset
Y_train_classes = list(set(Y_train))

loss_meter, accuracy_meter = train_epoch(model, train_loader, optimizer, Y_train_classes)

  0%|          | 0/478 [00:00<?, ?it/s]

torch.Size([32, 12, 1000])
torch.Size([32, 12, 1000])
torch.Size([32, 256, 1])
torch.Size([32, 256])


  0%|          | 0/478 [00:04<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.52 GiB. GPU 0 has a total capacity of 15.77 GiB of which 800.38 MiB is free. Process 4924 has 14.99 GiB memory in use. Of the allocated memory 14.54 GiB is allocated by PyTorch, and 57.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)