# Multimodal Classification

In [None]:
#@title Load libraries
# libraries for the files in google drive
import torch
import torch.nn as nn
import torchvision.models as models
from pydrive.auth import GoogleAuth
from google.colab import drive
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import copy
import re
import pandas as pd
from io import StringIO
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import pyplot as plt, image
import random
import numpy as np
from torch.utils.data import DataLoader
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from collections import OrderedDict
import warnings
import time
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import nltk
from nltk.tokenize import word_tokenize
nltk.download('stopwords')
from nltk.corpus import stopwords
nltk.download('wordnet')
from nltk.stem import WordNetLemmatizer
nltk.download('punkt')
from sklearn.feature_extraction.text import TfidfVectorizer
import torch.utils.data
import csv
import torchvision
warnings.filterwarnings("ignore")

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
#@title Download content from google drive
# Download the dataset
# Authenticate the Colab user and hand the default application credentials
# to PyDrive so the dataset archive can be fetched by file id.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# NOTE: this rebinds the name `drive`, shadowing the `google.colab.drive`
# module imported in the first cell.
drive = GoogleDrive(gauth)

# Google Drive id of the dataset archive.
file_id = "1gMEr3dSoacNkSJlYnd8btBHtxEbGA6Ao"
download = drive.CreateFile({'id': file_id})

# Download the archive to the local disk
download.GetContentFile('multi-label-classification-competition-2023.zip')

# Unzip the downloaded archive quietly (notebook shell magic, not Python)
!unzip -qq multi-label-classification-competition-2023.zip

print("The download process for dataset is completed")

The download process for dataset is completed


In [None]:
#@title Utilities, models, train and validate functions
###############################################################################
# Utilities

class DatasetArray(Dataset):
    r"""Dataset yielding ``(image, [label,] caption, index)`` tuples.

    Args:
        data: Sequence of images, one array per sample, in a form accepted
            by ``transforms.ToPILImage``.
        labels: Per-sample label vectors, or ``None`` for an unlabelled
            (test) dataset.
        captions: Per-sample caption feature vectors.
        train: Stored on the instance. The same deterministic
            pre-processing is applied to training and evaluation data.
    """

    def __init__(self, data, labels, captions, train = True):
        # A missing label array marks the unlabelled (test) dataset.
        self.label_arr = None if labels is None else np.asarray(labels)
        # Keep the images as a plain sequence: pictures of different sizes
        # cannot be stacked into one regular ndarray, and recent NumPy
        # versions raise on such ragged input instead of building an
        # object array.
        self.data_arr = data if isinstance(data, np.ndarray) else list(data)
        self.caption_arr = np.asarray(captions)
        self.train = train
        # Pre-processing for EfficientNet-B4: bicubic resize
        # (interpolation = 3), centre crop, then ImageNet normalisation.
        # Defined once because the train and evaluation pipelines are
        # identical.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size = (320,320), interpolation = 3),
            transforms.CenterCrop(size = (300,300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
            ])

    def __len__(self):
        return len(self.data_arr)

    def __getitem__(self, index):
        """Return the transformed sample stored at ``index``.

        Returns:
            ``(image, label, caption, index)`` when labels are available,
            otherwise ``(image, caption, index)``. All tensors are float32.
        """
        # The transform already returns a tensor; convert the dtype in place
        # of re-wrapping it with ``torch.tensor`` (which copies and warns).
        image_tensor = self.transform(self.data_arr[index]).to(torch.float32)
        caption = torch.tensor(self.caption_arr[index], dtype=torch.float32)

        # testing data: no label available
        if self.label_arr is None:
            return image_tensor, caption, index

        # training and validation data
        label = torch.tensor(self.label_arr[index], dtype=torch.float32)
        return image_tensor, label, caption, index

# Prepare dataloader
def get_loader(batch_size =128, num_workers = 1, train=True, shuffle=True, sampling=False,
               data=None, labels=None, captions=None, classes=None):
    """Wrap images, labels and caption features in a ``DataLoader``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Worker processes used by the loader.
        train: Forwarded to ``DatasetArray``.
        shuffle: Whether the loader reshuffles the samples.
        sampling: Unused; kept for backward compatibility.
        data: Sequence of images.
        labels: Iterable of per-sample label collections, or ``None`` for an
            unlabelled (test) set.
        captions: Per-sample caption feature vectors.
        classes: Optional ordered collection of all class ids. When omitted,
            the module-level ``class_name`` (set by ``load_training_data``)
            is used if it exists.

    Returns:
        The configured ``DataLoader``.
    """
    if labels is not None:
        if classes is None:
            # Reuse the label set recorded while loading the training data.
            classes = globals().get("class_name")
        if classes is not None:
            # Encode every split against the same class list so that the
            # label columns line up even when a split lacks a rare class.
            labels = MultiLabelBinarizer(classes=list(classes)).fit_transform(labels)
        else:
            # No reference class list available: fit on this split alone.
            labels, _ = one_hot_encoding(labels)

    dataset = DatasetArray(data = data, labels = labels, captions=captions, train = train)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

# Load images, captions and labels
def load_training_data():
  """Load training images, raw label lists and TF-IDF caption features.

  Side effect: stores the class ids reported by ``one_hot_encoding`` in the
  module-level ``class_name``.

  Returns:
      ``(images, labels, captions)``. ``labels`` is the un-encoded
      ``Labels`` column, and ``captions`` is whatever
      ``proprocess_caption`` returns.
  """
  global class_name
  frame = read_csv_file("train.csv", train = True)
  images = get_images(frame["ImageID"])
  # Only the class ids are kept here; the raw label lists are returned and
  # are encoded later, when the loaders are built.
  _, class_name = one_hot_encoding(frame["Labels"])
  raw_labels = frame['Labels']
  caption_features = proprocess_caption(frame['Caption'])

  return images, raw_labels, caption_features

# Load images, captions
def load_testing_data():
  """Load test images and their caption features.

  Returns:
      ``(images, captions)``, where ``captions`` is produced by
      ``proprocess_test_caption`` (which reuses the vocabulary and
      vectorizer fitted on the training captions).
  """
  frame = read_csv_file("test.csv", train = False)
  # Images are read in the order of the ImageID column
  images = get_images(frame["ImageID"])
  # Vectorise the captions; nothing is fitted on test data
  caption_features = proprocess_test_caption(frame['Caption'])

  return images, caption_features

# Reading the csv
def read_csv_file(filename, train = True):
  """Parse a dataset CSV from ``COMP5329S1A2Dataset/`` into a DataFrame.

  The file is split by hand because captions may contain commas. Only the
  leading columns are split off, so the caption is always the whole
  remainder of the line.

  Args:
      filename: File name inside ``COMP5329S1A2Dataset/``.
      train: When true each row is ``ImageID,Labels,Caption`` and ``Labels``
          is parsed into a list of ints; otherwise each row is
          ``ImageID,Caption``.

  Returns:
      DataFrame with columns ``ImageID``, (``Labels``,) ``Caption``. Lines
      that do not start with ``<digits>.jpg`` (e.g. the header) are skipped.
  """
  row_pattern = re.compile(r'^\d+\.jpg')
  # Number of comma-separated columns that precede the caption.
  n_leading = 2 if train else 1
  csv_list = []
  with open("COMP5329S1A2Dataset/"+filename) as file:
    for line in file:
      if not row_pattern.match(line):
        continue
      # maxsplit keeps commas inside the caption intact.
      fields = line.rstrip("\r\n").split(",", n_leading)
      ImageID = fields[0]
      Caption = fields[n_leading] if len(fields) > n_leading else ""
      if train:
        raw_labels = fields[1] if len(fields) > 1 else ""
        Labels = [int(i) for i in raw_labels.split()]
        csv_list.append({"ImageID":ImageID, "Labels": Labels, "Caption": Caption})
      else:
        csv_list.append({"ImageID":ImageID, "Caption": Caption})

  return pd.DataFrame(csv_list)

# imbalance ratio per label
def IRLbl(labels):
    """Return the imbalance ratio of every label column.

    ``labels`` is a 2-D indicator matrix. The ratio for a column is the
    largest column sum divided by that column's own sum.
    """
    # Unpacking the shape rejects anything that is not two-dimensional.
    _n_samples, _n_labels = labels.shape
    positives_per_label = np.sum(labels, axis=0)
    return np.max(positives_per_label) / positives_per_label

def MeanIR(labels):
    """Return the mean of the per-label imbalance ratios from ``IRLbl``."""
    return np.mean(IRLbl(labels))

def ML_ROS(all_labels, indices=None, num_samples=None, Preset_MeanIR_value=0,
                 max_clone_percentage=1000, sample_size=3200):
    """Randomly oversample the minority labels of a multi-label matrix.

    A label is treated as a minority label when its ``IRLbl`` value exceeds
    the mean imbalance ratio. For each such label, batches of
    ``sample_size`` row indices are drawn from the rows carrying that label
    and appended to the label matrix. This continues until the label's
    ratio drops to the mean or the clone budget is used up.

    Args:
        all_labels: 2-D indicator matrix (rows are samples, columns are labels).
        indices: Accepted for API compatibility; not used after defaulting.
        num_samples: Accepted for API compatibility; not used after defaulting.
        Preset_MeanIR_value: Target mean imbalance ratio; ``0`` means
            "compute it from ``all_labels``".
        max_clone_percentage: Clone budget as a percentage of the number of
            rows in ``all_labels``.
        sample_size: Number of row indices drawn per iteration.

    Returns:
        ``(new_all_labels, oversampled_ids)``: the label matrix with the
        cloned rows appended, and the list of cloned row indices.
    """

    indices = list(range(len(all_labels))) \
        if indices is None else indices

    # if num_samples is not provided, default it to `len(indices)`
    # (neither `indices` nor `num_samples` is read again below)
    num_samples = len(indices) \
        if num_samples is None else num_samples

    MeanIR_value = MeanIR(all_labels) if Preset_MeanIR_value == 0 else Preset_MeanIR_value
    IRLbl_value = IRLbl(all_labels)
    # N is the number of samples, C is the number of labels
    N, C = all_labels.shape
    # row indices of the samples carrying each label
    indices_per_class = {}
    minority_classes = []
    # According to the pseudo code, maxSamplesToClone is the upper limit on the
    # number of samples that may be copied from the original dataset
    maxSamplesToClone = N / 100 * max_clone_percentage
    print('Max Clone Limit:', maxSamplesToClone)
    for i in range(C):
        ids = all_labels[:, i] == 1
        # Which samples carry label i
        indices_per_class[i] = [ii for ii, x in enumerate(ids) if x]
        # A label whose imbalance ratio exceeds the mean is a minority label
        if IRLbl_value[i] > MeanIR_value:
            minority_classes.append(i)

    # Starts as an alias of the input; np.concatenate below builds new arrays,
    # so the caller's matrix is not modified.
    new_all_labels = all_labels
    oversampled_ids = []
    minorNum = len(minority_classes)
    print(minorNum, 'minor classes.')

    for idx, i in enumerate(minority_classes):
        tid = time.time()
        while True:
            # Draw a batch of row indices from the rows carrying label i
            pick_id = list(np.random.choice(indices_per_class[i], sample_size))
            # The drawn indices are also added back to the pool for label i
            indices_per_class[i].extend(pick_id)
            # Append the label rows of the drawn samples to the running
            # matrix, then recompute the imbalance ratios on the result
            new_all_labels = np.concatenate([new_all_labels, all_labels[pick_id]], axis=0)
            oversampled_ids.extend(pick_id)

            newIrlbl = IRLbl(new_all_labels)
            # Stop once this label is no longer more imbalanced than the target
            if newIrlbl[i] <= MeanIR_value:
                print('\nMeanIR satisfied.', newIrlbl[i])
                break
            # Stop once the clone budget is exhausted
            if len(oversampled_ids) >= maxSamplesToClone:
                print('\nExceed max clone.', len(oversampled_ids))
                break
            # if IRLbl(new_all_labels)[i] <= MeanIR_value or len(oversampled_ids) >= maxSamplesToClone:
            #     break
            print("\roversample length:{}".format(len(oversampled_ids)), end='')
        print('Processed the %d/%d minor class:' % (idx+1, minorNum), i, time.time()-tid, 's')
        # The budget is shared by all labels: once it is spent, skip the rest
        if len(oversampled_ids) >= maxSamplesToClone:
            print('Exceed max clone. Exit', len(oversampled_ids))
            break
    return new_all_labels, oversampled_ids

# Prepare oversampled dataset with given ids
def oversample_dataset(data,labels,captions,ids):
  """Gather the samples at ``ids`` (duplicates allowed) from each input.

  Returns:
      ``(new_data, new_labels, new_captions)`` as three lists ordered like
      ``ids``.
  """
  picked = list(ids)
  new_data = [data[sample_id] for sample_id in picked]
  new_captions = [captions[sample_id] for sample_id in picked]
  new_labels = [labels[sample_id] for sample_id in picked]

  return new_data, new_labels, new_captions

# Embedding preprocessing
def remove_punctuation_re(x):
  """Drop every character that is neither a word character nor whitespace."""
  return re.sub(pattern=r'[^\w\s]', repl='', string=x)

def tokenize_lower(x):
  """Lower-case ``x`` and split it into tokens with NLTK's ``word_tokenize``."""
  lowered = x.lower()
  tokens = word_tokenize(lowered)
  return tokens

def remove_stopwords(x):
  """Return the tokens of ``x`` that are not English stop words.

  The stop-word set is built on the first call and cached on the function.
  This function is applied once per caption, and rebuilding the set from
  the NLTK corpus each time is wasted work.
  """
  stop_words = getattr(remove_stopwords, "_stop_words", None)
  if stop_words is None:
    stop_words = frozenset(stopwords.words('english'))
    remove_stopwords._stop_words = stop_words
  return [word for word in x if word not in stop_words]

def lemmatize(x):
  """Return the WordNet lemma of every token in ``x`` as a new list."""
  to_lemma = WordNetLemmatizer().lemmatize
  return list(map(to_lemma, x))

def proprocess_caption(caption):
  """Clean the training captions and turn them into TF-IDF features.

  Pipeline: strip punctuation, lower-case and tokenise, drop stop words,
  lemmatise, discard words seen fewer than 5 times, then fit a
  ``TfidfVectorizer`` on the re-joined captions.

  Side effects (module-level names reused for the test set):
      ``vocab``       list of kept words, in first-seen order,
      ``vectorizer``  the fitted ``TfidfVectorizer``,
      ``text_input``  number of TF-IDF feature columns.

  Args:
      caption: pandas Series of caption strings.

  Returns:
      The matrix returned by ``vectorizer.fit_transform``.
  """
  from collections import Counter  # only needed in this function

  global vocab, vectorizer, text_input

  for step in (remove_punctuation_re, tokenize_lower, remove_stopwords, lemmatize):
    caption = caption.apply(step)

  # count the word frequency over the whole corpus
  corpus = Counter(word for cap in caption for word in cap)

  # set a threshold for the word frequency
  threshold = 5
  vocab = [word for word, count in corpus.items() if count >= threshold]
  # A set makes the per-word membership test O(1); scanning the vocab list
  # for every word is quadratic on a large corpus.
  kept_words = set(vocab)
  caption = caption.apply(lambda x: [word for word in x if word in kept_words])
  caption = caption.apply(' '.join)

  # fit here only; the test set must reuse this exact vectorizer
  vectorizer = TfidfVectorizer()
  caption = vectorizer.fit_transform(caption)
  text_input = caption.shape[1]
  return caption

def proprocess_test_caption(caption):
  """Clean the test captions and project them onto the training TF-IDF space.

  Applies the same cleaning steps as ``proprocess_caption``. It reuses the
  module-level ``vocab`` and ``vectorizer`` created there, so
  ``proprocess_caption`` must have run first.

  Args:
      caption: pandas Series of caption strings.

  Returns:
      The matrix returned by ``vectorizer.transform``.
  """
  for step in (remove_punctuation_re, tokenize_lower, remove_stopwords, lemmatize):
    caption = caption.apply(step)

  # A set makes the per-word membership test O(1) instead of scanning the list.
  kept_words = set(vocab)
  caption = caption.apply(lambda x: [word for word in x if word in kept_words])
  caption = caption.apply(' '.join)
  # don't fit again: reuse the vectorizer fitted on the training captions
  return vectorizer.transform(caption)

# One hot encoding
def one_hot_encoding(labels):
    """Binarise multi-label targets.

    Returns:
        ``(encoded, classes)``: the indicator matrix from a freshly fitted
        ``MultiLabelBinarizer`` and that binarizer's ``classes_``.
    """
    binarizer = MultiLabelBinarizer()
    encoded = binarizer.fit_transform(labels)
    return encoded, binarizer.classes_

# Get images
def get_images(image_id_list):
  """Read every image named in ``image_id_list`` from the dataset folder.

  Returns:
      List of arrays as returned by ``image.imread``, in input order.
  """
  img_path = "COMP5329S1A2Dataset/data"
  return [
      image.imread(os.path.join(img_path, image_id))
      for image_id in image_id_list
  ]

# Get the list of class name for classification report
def get_class_list():
  """Build display names (``"class<id>"``) from the global ``class_name``.

  The list is also stored in the module-level ``class_list``, which the
  validation functions read for the classification report.
  """
  global class_list
  class_list = ["class"+str(class_id) for class_id in class_name]
  return class_list

def calculate_weighted_score(y_true, y_pred):
    """Return micro-averaged ``(f1, precision, recall)``.

    ``y_pred`` is rounded with ``np.round`` before scoring.
    """
    rounded = np.round(y_pred)
    scorers = (f1_score, precision_score, recall_score)
    return tuple(score(y_true, rounded, average='micro') for score in scorers)

# split the data into train and validation sets
def train_val_random_split(data, labels, captions, fracs=(0.8, 0.2)):
    """Randomly split parallel sequences into two disjoint subsets.

    Args:
        data, labels, captions: Parallel, integer-indexable collections.
        fracs: Two positive fractions summing to 1 (train, validation).

    Returns:
        ``(new_data, new_labels, new_captions)``, each a list of two lists
        (first subset, second subset). The first subset holds
        ``int(n * fracs[0])`` samples and the second holds all the rest, so
        no sample is dropped by rounding.
    """
    import math  # local import: only needed for the tolerance check

    assert len(fracs) == 2
    # Compare with a tolerance: fractions such as 0.58 + 0.42 need not sum
    # to exactly 1.0 in floating point.
    assert math.isclose(sum(fracs), 1.0)
    assert all(frac > 0 for frac in fracs)

    n = len(data)
    idxs = list(range(n))
    random.shuffle(idxs)

    # The second subset takes the remainder instead of int(n * frac), which
    # could silently lose samples to truncation.
    first_len = int(n * fracs[0])
    partitions = [idxs[:first_len], idxs[first_len:]]

    new_data = [[data[i] for i in part] for part in partitions]
    new_labels = [[labels[i] for i in part] for part in partitions]
    new_captions = [[captions[i] for i in part] for part in partitions]

    return new_data, new_labels, new_captions

class EarlyStopping():
  """Stop training when the validation F1 has not improved for ``patience`` calls.

  The best-scoring weights are kept in a private copy of the model. When
  the patience runs out they are loaded back into the live model.
  """

  def __init__(self,  patience=10):
    self.patience = patience
    self.best_model = None
    self.best_f1 = None
    self.counter = 0

  def __call__(self, model, val_f1):
    """Record ``val_f1`` and return ``True`` when training should stop."""
    # First call: take a full snapshot of the model
    if self.best_f1 is None:
      self.best_f1 = val_f1
      self.best_model = copy.deepcopy(model)
      return False

    # Improvement: refresh the snapshot and reset the patience counter
    if val_f1 > self.best_f1:
      self.best_f1 = val_f1
      self.counter = 0
      self.best_model.load_state_dict(model.state_dict())
      return False

    # No improvement: spend one unit of patience
    if val_f1 <= self.best_f1:
      self.counter += 1
      if self.counter >= self.patience:
        # Restore the weights of the best model before stopping
        model.load_state_dict(self.best_model.state_dict())
        return True
    return False

# Focal loss
def focal_bce(outputs, targets, gamma=2):
    """Focal binary cross-entropy computed on raw logits.

    Args:
        outputs: Logits (the sigmoid is applied here).
        targets: Targets of the same total size; values >= 0.5 count as
            positive.
        gamma: Focusing exponent applied to ``1 - p``.

    Returns:
        Scalar tensor: the mean focal loss multiplied by 18.
    """
    logits = outputs.reshape(-1)
    flat_targets = targets.reshape(-1)
    prob = torch.sigmoid(logits)
    # Probability the model assigns to the true outcome of each element
    prob_true = torch.where(flat_targets >= 0.5, prob, 1-prob)
    # Clamp before the log to avoid infinities
    neg_log = - torch.log(torch.clamp(prob_true, 1e-4, 1-1e-4))
    modulating = (1-prob_true)**gamma
    return 18*(neg_log*modulating).mean()

###############################################################################
# Models
###############################################################################

# EfficientNet-b4
def eff_b4(pretrained=True, fine_tune=True):
    """Build an EfficientNet-B4 with an 18-output classification head.

    Args:
        pretrained: Forwarded to ``models.efficientnet_b4``.
        fine_tune: When true every backbone parameter stays trainable,
            otherwise all of them are frozen. The replacement head is
            created afterwards, so its parameters remain trainable.

    Returns:
        The model, with ``classifier[1]`` replaced by the new head.
    """
    print('[INFO]: Loading pre-trained weights' if pretrained
          else '[INFO]: Not loading pre-trained weights')
    model = models.efficientnet_b4(pretrained=pretrained)

    print('[INFO]: Fine-tuning all layers...' if fine_tune
          else '[INFO]: Freezing hidden layers...')
    trainable = bool(fine_tune)
    for params in model.parameters():
        params.requires_grad = trainable

    # MLP head: 1792 -> 1024 -> 512 -> 128 -> 18
    head_layers = [
        nn.BatchNorm1d(num_features=1792),
        nn.Linear(1792, 1024),
        nn.ReLU(),
        nn.BatchNorm1d(1024),
        nn.Dropout(0.4),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=128),
        nn.Dropout(0.4),
        nn.Linear(128, 18),
    ]
    model.classifier[1] = nn.Sequential(*head_layers)
    return model

# text model
class text_branch(nn.Module):
    """Two-layer MLP mapping caption features to 18 sigmoid outputs.

    Args:
        input_size: Number of input features per caption.
        dropout_rate: Dropout probability after the hidden layer.
    """

    def __init__(self, input_size, dropout_rate=0.5):
        super().__init__()
        hidden_units = 512
        layers = [
            nn.Linear(input_size, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_units, 18),
            nn.Sigmoid(),  # outputs are already probabilities
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Combined model
class image_text_net(nn.Module):
  """Late-fusion classifier stacked on frozen image and text branches.

  The outputs of the two branches are concatenated and passed through a
  small MLP that expects 36 input features and emits 18 outputs. Only that
  MLP is trained.

  Args:
      image_model: Image branch; moved to the global ``device``.
      text_model: Text branch; moved to the global ``device``.
      pretrained: Unused.
  """

  def __init__(self,image_model,text_model,pretrained=True):
    super().__init__()
    self.image_model = image_model.to(device)
    self.text_model = text_model.to(device)
    fusion_layers = [
        nn.ReLU(),
        nn.Linear(36,324),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(324,108),
        nn.ReLU(),
        nn.Linear(108,18),
    ]
    self.model = nn.Sequential(*fusion_layers)

    # Freeze both branches: only the fusion head receives gradients
    for branch in (self.image_model, self.text_model):
      for param in branch.parameters():
        param.requires_grad = False

  def forward(self,x,y):
    # Force the branches into eval mode on every call, even while the
    # enclosing module is in train mode
    self.image_model.eval()
    self.text_model.eval()
    with torch.no_grad():
      image_features = self.image_model(x)
      text_features = self.text_model(y)
    fused = torch.cat((image_features,text_features),1)
    return self.model(fused)

# Train funciton for image model
def train(epoch, model, optimizer, criterion, train_loader, threshold = 0.5):
    """Run one training epoch of the image model.

    Args:
        epoch: Unused.
        model: Image model returning logits.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to the sigmoid outputs and the targets.
        train_loader: Yields ``(data, targets, captions, indices)``.
        threshold: Cut-off turning probabilities into predictions for F1.

    Returns:
        ``(train_loss, train_f1)``: the mean over batches of the loss and of
        the per-batch micro-F1.
    """
    model.train()
    batch_losses = []
    batch_f1s = []
    for data, targets, _captions, _indices in train_loader:
        data = data.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        # if using focal loss, the sigmoid needs to be removed
        outputs = torch.sigmoid(model(data))
        loss = criterion(outputs, targets)
        batch_losses.append(loss.item())
        batch_f1s.append(
            f1_score(targets.cpu().detach().numpy(),
                     (outputs>threshold).cpu().detach().numpy(),
                     average='micro'))
        loss.backward()
        optimizer.step()

    n_batches = len(batch_losses)
    train_loss = sum(batch_losses) / n_batches
    train_f1 = sum(batch_f1s) / n_batches

    return train_loss, train_f1

# Validate function for image model
def validate(epoch, model, criterion, val_loader, is_test=False, threshold = 0.5):
    """Evaluate the image model and print a per-class report.

    Args:
        epoch: Unused.
        model: Image model returning logits.
        criterion: Loss applied to the sigmoid outputs and the targets.
        val_loader: Yields ``(data, targets, captions, indices)``.
        is_test: Unused.
        threshold: Cut-off turning probabilities into predictions. It is
            used for both the F1 score and the classification report.

    Returns:
        ``(val_loss, val_f1)``: the mean over batches of the loss and of the
        per-batch micro-F1.

    Reads the module-level ``class_list`` for the report's class names.
    """
    model.eval()
    batch_losses = []
    batch_f1s = []
    target_list = []
    pred_list = []
    with torch.no_grad():
        for data, targets, _captions, _indices in val_loader:
            data = data.to(device)
            targets = targets.to(device)
            # apply sigmoid activation to get all the outputs between 0 and 1
            # (if using focal loss, the sigmoid needs to be removed)
            outputs = torch.sigmoid(model(data))
            loss = criterion(outputs, targets)
            batch_losses.append(loss.item())
            # Threshold once, and use the caller's threshold for the report
            # as well; the report used a hard-coded 0.5 before.
            pred_labels = (outputs>threshold).cpu().detach().numpy()
            target_labels = targets.cpu().detach().numpy()
            batch_f1s.append(f1_score(target_labels, pred_labels, average='micro'))
            pred_list.append(pred_labels)
            target_list.append(target_labels)

    n_batches = len(batch_losses)
    val_loss = sum(batch_losses) / n_batches
    val_f1 = sum(batch_f1s) / n_batches
    # compute classification report over the whole validation set
    pred_list = np.concatenate(pred_list)
    target_list = np.concatenate(target_list)
    report = classification_report(target_list, pred_list, target_names= class_list)
    print(report)

    return val_loss, val_f1

# Train and validate for text model
def train_validate_text(epochs, model, optimizer, criterion, train_loader, val_loader):
  """Train the text model with per-epoch validation and early stopping.

  The model is expected to output probabilities, which are thresholded at
  0.5. After every epoch a classification report is printed. The weights
  are saved to ``text_model.pt`` whenever the validation F1 improves, and
  training stops after 5 consecutive epochs without improvement.

  Args:
      epochs: Maximum number of epochs.
      model: Text model fed with the caption features.
      optimizer: Optimizer stepped once per batch.
      criterion: Loss applied to the model outputs and the targets.
      train_loader, val_loader: Yield ``(data, targets, captions, indices)``.
  """
  patience = 5
  best_val_f1 = 0
  stale_epochs = 0
  # 18 report names: ids 1..19 without 12
  class_names = [f'Class {i}' for i in range(1, 20) if i != 12]

  def batch_micro_f1(targets, outputs):
    return f1_score(targets.cpu().detach().numpy(),
                    (outputs>0.5).cpu().detach().numpy(),
                    average='micro')

  for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_running_f1 = 0
    for _data, targets, captions, _indices in train_loader:
      targets = targets.to(device)
      captions = captions.to(device)
      outputs = model(captions)
      loss = criterion(outputs, targets)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      train_running_f1 += batch_micro_f1(targets, outputs)
    train_f1 = train_running_f1 / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    val_running_f1 = 0
    output_batches = []
    label_batches = []
    with torch.no_grad():
      for _data, targets, captions, _indices in val_loader:
        targets = targets.to(device)
        captions = captions.to(device)
        outputs = model(captions)
        val_loss += criterion(outputs, targets).item()
        val_running_f1 += batch_micro_f1(targets, outputs)
        output_batches.append(outputs.cpu().numpy())
        label_batches.append(targets.cpu().numpy())
    validation_loss = val_loss / len(val_loader)
    val_f1 = val_running_f1 / len(val_loader)

    y_pred_np = np.round(np.concatenate(output_batches))
    report = classification_report(np.concatenate(label_batches), y_pred_np, target_names=class_names)
    print(report)
    print(f'Epoch {epoch+1}/{epochs}, Training F1: {train_f1:.4f}, Validation F1:{val_f1:.4f}, Validation loss:{validation_loss:.4f}')

    # ---- checkpointing / early stopping ----
    if val_f1 > best_val_f1:
      print('best model saved!')
      stale_epochs = 0
      best_val_f1 = val_f1
      torch.save(model.state_dict(), 'text_model.pt')
      continue
    stale_epochs += 1
    if stale_epochs >= patience:
      print('Early stopping!')
      break

# Train funciton for combined model
def train_comb(epoch, model, optimizer, criterion, train_loader, threshold = 0.5):
    """Run one training epoch of the combined image + text model.

    Args:
        epoch: Unused.
        model: Combined model called as ``model(images, captions)``; returns
            logits.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to the sigmoid outputs and the targets.
        train_loader: Yields ``(data, targets, captions, indices)``.
        threshold: Cut-off turning probabilities into predictions for F1.

    Returns:
        ``(train_loss, train_f1)``: the mean over batches of the loss and of
        the per-batch micro-F1.
    """
    model.train()
    batch_losses = []
    batch_f1s = []
    for data, targets, captions, _indices in train_loader:
        data = data.to(device)
        targets = targets.to(device)
        captions = captions.to(device)
        optimizer.zero_grad()
        outputs = torch.sigmoid(model(data,captions))
        loss = criterion(outputs, targets)
        batch_losses.append(loss.item())
        batch_f1s.append(
            f1_score(targets.cpu().detach().numpy(),
                     (outputs>threshold).cpu().detach().numpy(),
                     average='micro'))
        loss.backward()
        optimizer.step()

    n_batches = len(batch_losses)
    train_loss = sum(batch_losses) / n_batches
    train_f1 = sum(batch_f1s) / n_batches

    return train_loss, train_f1

# Validate function for combined model
def validate_comb(epoch, model, criterion, val_loader, is_test=False, threshold = 0.5):
    """Evaluate the combined image + text model and print a per-class report.

    Args:
        epoch: Unused.
        model: Combined model called as ``model(images, captions)``; returns
            logits.
        criterion: Loss applied to the sigmoid outputs and the targets.
        val_loader: Yields ``(data, targets, captions, indices)``.
        is_test: Unused.
        threshold: Cut-off turning probabilities into predictions. It is
            used for both the F1 score and the classification report.

    Returns:
        ``(val_loss, val_f1)``: the mean over batches of the loss and of the
        per-batch micro-F1.

    Reads the module-level ``class_list`` for the report's class names.
    """
    model.eval()
    batch_losses = []
    batch_f1s = []
    target_list = []
    pred_list = []
    with torch.no_grad():
        for data, targets, captions, _indices in val_loader:
            data = data.to(device)
            targets = targets.to(device)
            captions = captions.to(device)
            outputs = torch.sigmoid(model(data,captions))
            loss = criterion(outputs, targets)
            batch_losses.append(loss.item())
            # Threshold once, and use the caller's threshold for the report
            # as well; the report used a hard-coded 0.5 before.
            pred_labels = (outputs>threshold).cpu().detach().numpy()
            target_labels = targets.cpu().detach().numpy()
            batch_f1s.append(f1_score(target_labels, pred_labels, average='micro'))
            pred_list.append(pred_labels)
            target_list.append(target_labels)

    n_batches = len(batch_losses)
    val_loss = sum(batch_losses) / n_batches
    val_f1 = sum(batch_f1s) / n_batches
    pred_list = np.concatenate(pred_list)
    target_list = np.concatenate(target_list)
    report = classification_report(target_list, pred_list, target_names= class_list)
    print(report)

    return val_loss, val_f1

In [None]:
#@title Load data and prepare loaders
# Pick the GPU when available; `device` is read as a global by the models and
# by the train/validate functions. (`global` at module level has no effect.)
global device
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
# Load images, raw label lists and caption features. The training data must
# be loaded first: it creates the vocabulary and vectorizer that the test
# captions reuse.
data, labels, captions = load_training_data()
test_data, test_captions = load_testing_data()

# Split the training data 90/10 into training and validation sets.
# The caption matrix is converted to a dense array so that it can be indexed
# row by row.
[train_data, val_data],[train_labels, val_labels],[train_captions, val_captions] = train_val_random_split(data, labels, captions.toarray(), fracs = [0.9,0.1])

# Keep the same data type for the testing set: one dense 1-D vector per caption
test_captions = [np.squeeze(i.toarray()) for i in test_captions]

# Prepare training, validation and test loaders.
# Validation and test are not shuffled; the test loader has no labels.
val_loader = get_loader(batch_size = 100, num_workers = 1,
                          train = False, shuffle = False,
                          data=val_data, labels=val_labels, captions=val_captions)
train_loader = get_loader(batch_size = 32, num_workers = 1,
                          data=train_data, labels=train_labels, captions=train_captions)
test_loader = get_loader(batch_size = 100, num_workers = 1,
                          train = False, shuffle = False,
                          data=test_data, labels=None, captions=test_captions)

In [None]:
#@title Train the image model
# Train the image model
torch.cuda.empty_cache()
# EfficientNet-B4 with a frozen backbone (fine_tune=False): only the
# replacement classification head is trainable
model = eff_b4(pretrained=True, fine_tune=False).to(device)
# Criterion: BCELoss expects probabilities, so train()/validate() apply the
# sigmoid before calling it
criterion = nn.BCELoss()
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr= 0.0001)
# Class names used by validate() for the classification report
class_list = get_class_list()
# initialize early_stopping (Change patience if you need to)
early_stopping = EarlyStopping(patience = 10)
best_f1 = 0
epochs = 1
# Train and validate the image model
for epoch in range(1, epochs + 1):
  # Train
  train_loss, train_f1 = train(epoch, model, optimizer, criterion, train_loader)
  # Validation
  val_loss, val_f1 = validate(epoch, model, criterion, val_loader)
  print(f"[Epoch {epoch}/{epochs}] Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Validation Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}")
  # Checkpoint whenever the validation F1 improves; the combined-model cell
  # loads this file
  if val_f1 > best_f1:
    best_f1 = val_f1
    # torch.save(model.state_dict(), "effb3_oversample_model.pth")
    torch.save(model.state_dict(), "effb4_model.pt")
  # Check early stopping  
  if early_stopping(model, val_f1):
    print(f"Early stopping at {epoch} epoch!")
    break

[INFO]: Loading pre-trained weights


Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-7eb33cd5.pth
100%|██████████| 74.5M/74.5M [00:04<00:00, 17.2MB/s]


[INFO]: Freezing hidden layers...
              precision    recall  f1-score   support

      class1       0.97      0.80      0.87      2293
      class2       0.24      0.36      0.29       111
      class3       0.43      0.69      0.53       426
      class4       0.22      0.92      0.35       142
      class5       0.15      0.97      0.26       110
      class6       0.19      0.85      0.31       120
      class7       0.26      0.92      0.41       108
      class8       0.32      0.71      0.44       215
      class9       0.18      0.77      0.29        91
     class10       0.19      0.80      0.30       119
     class11       0.10      0.23      0.14        65
     class13       0.27      0.59      0.37        66
     class14       0.05      0.23      0.08        31
     class15       0.19      0.44      0.27       199
     class16       0.23      0.51      0.32       108
     class17       0.44      0.89      0.59       144
     class18       0.22      0.50      0.30    

In [None]:
#@title Train the text model
# Train the text model
# Run training and validation on the text branch
class_list = get_class_list()
# Text branch: MLP whose input width is the number of TF-IDF features
# (`text_input` is set by proprocess_caption)
model = text_branch(input_size=text_input).to(device)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, amsgrad=True)
# Criterion: the text branch ends in a sigmoid, so BCELoss is applied directly
criterion = nn.BCELoss()
# Saves the best weights to 'text_model.pt'
train_validate_text(1, model, optimizer, criterion, train_loader, val_loader)

              precision    recall  f1-score   support

     Class 1       0.76      1.00      0.87      2293
     Class 2       0.00      0.00      0.00       111
     Class 3       0.00      0.00      0.00       426
     Class 4       0.00      0.00      0.00       142
     Class 5       0.00      0.00      0.00       110
     Class 6       0.00      0.00      0.00       120
     Class 7       0.00      0.00      0.00       108
     Class 8       0.00      0.00      0.00       215
     Class 9       0.00      0.00      0.00        91
    Class 10       0.00      0.00      0.00       119
    Class 11       0.00      0.00      0.00        65
    Class 13       0.00      0.00      0.00        66
    Class 14       0.00      0.00      0.00        31
    Class 15       0.00      0.00      0.00       199
    Class 16       0.00      0.00      0.00       108
    Class 17       0.00      0.00      0.00       144
    Class 18       0.00      0.00      0.00       157
    Class 19       0.00    

In [None]:
#@title Train the combined model
# Train the combined model
torch.cuda.empty_cache()

# Rebuild both branches with the same architecture as in the cells above
image_model = eff_b4(pretrained=True, fine_tune=False).to(device)
text_model = text_branch(input_size=text_input).to(device)
# Load the weights saved by the image-model and text-model training cells
image_model.load_state_dict(torch.load('effb4_model.pt'))
text_model.load_state_dict(torch.load('text_model.pt'))

# The fusion network freezes both branches; only its own head is trained
model = image_text_net(image_model,text_model).to(device)
# BCELoss expects probabilities: train_comb()/validate_comb() apply the sigmoid
criterion = nn.BCELoss()

optimizer = torch.optim.Adam(model.parameters(), lr= 0.0001)
early_stopping = EarlyStopping(patience = 10)
best_f1 = 0
epochs = 1
# Train and validate the combined model
for epoch in range(1, epochs + 1):
  # Train
  train_loss, train_f1 = train_comb(epoch, model, optimizer, criterion, train_loader)
  # Validation
  val_loss, val_f1 = validate_comb(epoch, model, criterion, val_loader)
  print(f"[Epoch {epoch}/{epochs}] Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Validation Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}")
  # Checkpoint whenever the validation F1 improves
  if val_f1 > best_f1:
    best_f1 = val_f1
    torch.save(model.state_dict(), "model.pt")
  # Check early stopping  
  if early_stopping(model, val_f1):
    print(f"Early stopping at {epoch} epoch!")
    break

[INFO]: Loading pre-trained weights
[INFO]: Freezing hidden layers...
              precision    recall  f1-score   support

      class1       0.83      0.97      0.90      2293
      class2       0.00      0.00      0.00       111
      class3       0.66      0.29      0.41       426
      class4       0.00      0.00      0.00       142
      class5       0.99      0.68      0.81       110
      class6       0.00      0.00      0.00       120
      class7       1.00      0.05      0.09       108
      class8       0.25      0.00      0.01       215
      class9       0.00      0.00      0.00        91
     class10       0.00      0.00      0.00       119
     class11       0.00      0.00      0.00        65
     class13       0.00      0.00      0.00        66
     class14       0.00      0.00      0.00        31
     class15       0.00      0.00      0.00       199
     class16       0.00      0.00      0.00       108
     class17       1.00      0.33      0.49       144
     class1