In [None]:
import torch

# Select the compute device used by every later cell. The variable keeps its
# historical name `mps_device` because the rest of the notebook refers to it,
# but it now always gets a value: MPS when available, otherwise CUDA, otherwise
# CPU. Without the fallback, later `.to(mps_device)` calls raise NameError.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Smoke test: allocate a tensor on the MPS device and show it.
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")
    mps_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Falling back to device: {mps_device}")

# Imports

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import resample
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import imgaug.augmenters as iaa
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import tqdm
import tqdm.notebook as tqdm
from torch.nn.utils import clip_grad_norm_
import time
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import textwrap
import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss, classification_report


# Prepare Data

In [None]:
# Load the pre-split allergen tables and drop columns not used for training.
# Column 0 is the image path (CustomDataset opens it); every column from
# LABEL_START_COL onward is one allergen label.
LABEL_START_COL = 4
columns_to_drop = ['target', 'class_id']

train_df = pd.read_csv('./data/allergens_train.csv').drop(columns_to_drop, axis=1)
val_df = pd.read_csv('./data/allergens_val.csv').drop(columns_to_drop, axis=1)
test_df = pd.read_csv('./data/allergens_test.csv').drop(columns_to_drop, axis=1)

# Labels are read by column position, so every split must share the exact
# column order of the training split or labels would be silently misaligned.
for split_name, split_df in (('val', val_df), ('test', test_df)):
    assert list(split_df.columns) == list(train_df.columns), (
        f"{split_name} split columns do not match the train split columns")

allergens = train_df.columns[LABEL_START_COL:]
train_df.head()

In [None]:
# Pairwise Pearson correlation between allergens, computed on presence
# indicators: any value > 0 counts as present (1), anything else as absent (0).
mapped_data = (train_df[allergens] > 0).astype(int).to_numpy().T  # one row per allergen
co_all = np.corrcoef(mapped_data)

fig, ax1 = plt.subplots(1, 1, figsize=(30, 25))
heatmap = sns.heatmap(
    co_all,
    ax=ax1,
    annot=True,
    fmt='2.1%',
    cmap='RdBu',
    vmin=-1,
    vmax=1,
)

# Label both axes with allergen names instead of numeric indices.
ax1.set_xticklabels(allergens, rotation=45, ha='right')
ax1.set_yticklabels(allergens, rotation=0, ha='right')
plt.show()


# Preprocessing data

In [None]:
image_size = (224, 224)

class CustomDataset(Dataset):
    """Multi-label image dataset backed by a DataFrame.

    Column 0 of ``dataframe`` holds the image file path; every column from
    ``label_start`` onward holds one label value.

    Args:
        dataframe: table of image paths and label columns.
        transform: optional callable applied to the PIL image (e.g. a
            torchvision ``Compose`` ending in ``ToTensor``).
        augmentation: optional imgaug augmenter, applied to the image as a
            numpy array before ``transform``.
        label_start: index of the first label column. The default of 4 matches
            ``allergens = train_df.columns[4:]`` used elsewhere in the notebook.
    """

    def __init__(self, dataframe, transform=None, augmentation=None, label_start=4):
        self.dataframe = dataframe
        self.transform = transform
        self.augmentation = augmentation
        self.label_start = label_start

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for row ``idx``."""
        img_name = self.dataframe.iloc[idx, 0]
        # Open inside a context manager so the file handle is released
        # immediately. convert("RGB") loads the pixels (and returns a copy even
        # when the image is already RGB), so the result outlives the file.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        if self.augmentation is not None:
            augmented = self.augmentation(images=[np.array(image)])[0]
            image = Image.fromarray(augmented)

        if self.transform is not None:
            image = self.transform(image)

        if not isinstance(image, torch.Tensor):
            # No transform produced a tensor: convert HWC uint8 pixels to a CHW
            # float tensor in [0, 1] (the layout/scale ToTensor produces)
            # rather than failing in the legacy torch.Tensor constructor.
            pixels = np.asarray(image, dtype=np.float32) / 255.0
            image = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()

        label_values = self.dataframe.iloc[idx, self.label_start:].values.astype(float)
        label = torch.as_tensor(label_values, dtype=torch.float32)

        return image, label

# Image-level augmentation (imgaug), used for the training split only and
# applied before the torchvision transforms.
augmentation = iaa.Sequential([
    iaa.Fliplr(0.5),
    iaa.Crop(percent=(0, 0.2)),
    iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),
    # LinearContrast is the non-deprecated name of ContrastNormalization.
    iaa.LinearContrast((0.75, 1.5)),
    iaa.AdditiveGaussianNoise(loc=0, scale=(0.2, 0.05 * 255), per_channel=0.5),
    iaa.Multiply((0.8, 1.2), per_channel=0.2),
    iaa.Affine(rotate=(-30, 30)),
])

# Training transforms: resize, convert to a tensor, then random brightness and
# affine jitter applied to the tensor.
train_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.ColorJitter(brightness=(0.7, 1.3)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), shear=0.01, scale=(0.9, 1.25))
])

# Evaluation transforms: deterministic resize + tensor conversion only.
test_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

train_dataset = CustomDataset(dataframe=train_df, transform=train_transform, augmentation=augmentation)
val_dataset = CustomDataset(dataframe=val_df, transform=test_transform)
test_dataset = CustomDataset(dataframe=test_df, transform=test_transform)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation loaders keep a fixed order so test metrics and the sample
# prediction grid are reproducible between runs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
def show_augmentations(dataset, original_idx=0, num_augmentations=5):
    """Plot one sample next to several independently augmented draws of it.

    The first panel is fetched with the dataset's imgaug ``augmentation``
    temporarily switched off, so the panel titled "Original" is not itself an
    imgaug augmentation. ``dataset.transform`` still runs for that panel; for
    ``train_transform`` that includes random ColorJitter/RandomAffine.

    Args:
        dataset: dataset returning ``(image_tensor, label)``; image tensors are
            assumed to be CHW with values ``imshow`` can display.
        original_idx: index of the sample to show.
        num_augmentations: number of augmented draws shown after the original.
    """
    saved_augmentation = getattr(dataset, 'augmentation', None)
    if saved_augmentation is not None:
        dataset.augmentation = None
    try:
        original_img, _ = dataset[original_idx]
    finally:
        # Always restore the pipeline, even if loading the image fails.
        if saved_augmentation is not None:
            dataset.augmentation = saved_augmentation

    # Each access re-runs the random augmentation, yielding a new variant.
    panels = [('Original', original_img)]
    panels.extend((f'Augmentation {i}', dataset[original_idx][0])
                  for i in range(1, num_augmentations + 1))

    # squeeze=False keeps `axes` 2-D even when only one panel is requested.
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 3), squeeze=False)
    for ax, (title, img) in zip(axes[0], panels):
        ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))
        ax.axis('off')
        ax.set_title(title)

    plt.show()
# Preview augmentations for the first few training samples, bounded by the
# dataset size so a small dataset does not raise IndexError.
NUM_PREVIEW_SAMPLES = 10
for sample_idx in range(min(NUM_PREVIEW_SAMPLES, len(train_dataset))):
    show_augmentations(train_dataset, original_idx=sample_idx, num_augmentations=5)


In [None]:
# Number of allergen label columns (the default output size of FineTunedResNet).
num_allergens = len(allergens)
num_allergens

# Train Model

In [None]:
class FineTunedResNet(nn.Module):
    """ImageNet-pretrained ResNet-50 whose final layer is replaced.

    Every backbone layer is kept (nothing is frozen here); only the original
    fully connected layer is swapped for one producing ``num_classes`` raw
    logits. The default ``num_classes`` is read from the global ``allergens``
    at class-definition time.
    """

    def __init__(self, num_classes=len(allergens)):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # All children except the last one (the original classifier `fc`).
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Map a batch of images to per-class logits."""
        pooled = self.features(x)
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)

def Model():
    """Return a new ``FineTunedResNet`` built with its default arguments."""
    network = FineTunedResNet()
    return network

# Build the network and move it to the selected device. nn.Module.to() moves
# parameters in place and returns the module itself, so `model` is the
# on-device network; leaving it as the last expression displays its repr.
model = Model().to(mps_device)
model

In [None]:
def train_loop(train_loader, criterion, optimizer, train_loss, correct_train, total_train,
               net=None, device=None, show_progress=True):
    """Run one training epoch and return the updated running totals.

    Args:
        train_loader: iterable of ``(images, targets)`` batches; ``targets`` is a
            2-D ``(batch, labels)`` tensor where 1.0 marks a positive label.
        criterion: loss taking the raw model outputs (logits) and ``targets``.
        optimizer: optimizer over ``net``'s parameters.
        train_loss: running sum of per-batch losses to add to.
        correct_train: running count of correct per-label predictions.
        total_train: running count of per-label predictions.
        net: model to train; defaults to the notebook-level ``model``.
        device: device batches are moved to; defaults to ``mps_device``.
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        ``(train_loss, correct_train, total_train)``. A label counts as
        predicted positive when ``sigmoid(logit) >= 0.5``.
    """
    # Resolve notebook globals lazily so callers passing explicit arguments
    # don't need them to exist.
    if net is None:
        net = model
    if device is None:
        device = mps_device
    net.train()

    batches = tqdm.tqdm(train_loader, desc='Training') if show_progress else train_loader
    for images, targets in batches:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(images)

        loss = criterion(outputs, targets)
        loss.backward()
        # Cap the global gradient norm at 1 before the update.
        clip_grad_norm_(net.parameters(), max_norm=1)
        optimizer.step()
        train_loss += loss.item()

        # Per-label (element-wise) accuracy bookkeeping.
        total_train += targets.numel()
        predicted = torch.sigmoid(outputs) >= 0.5
        correct_train += (predicted == (targets == 1.0)).sum().item()
    return train_loss, correct_train, total_train

def evaluate(val_loader, criterion, optimizer, val_loss, correct_val, total_val,
             net=None, device=None):
    """Evaluate on ``val_loader`` without gradients and return running totals.

    The model is left in eval mode on return.

    Args:
        val_loader: iterable of ``(images, targets)`` batches; ``targets`` is a
            2-D ``(batch, labels)`` tensor where 1.0 marks a positive label.
        criterion: loss taking the raw model outputs (logits) and ``targets``.
        optimizer: unused; kept so existing call sites keep working.
        val_loss: running sum of per-batch losses to add to.
        correct_val: running count of correct per-label predictions.
        total_val: running count of per-label predictions.
        net: model to evaluate; defaults to the notebook-level ``model``.
        device: device batches are moved to; defaults to ``mps_device``.

    Returns:
        ``(val_loss, correct_val, total_val)``. A label counts as predicted
        positive when ``sigmoid(logit) >= 0.5``.
    """
    # Resolve notebook globals lazily so callers passing explicit arguments
    # don't need them to exist.
    if net is None:
        net = model
    if device is None:
        device = mps_device
    net.eval()
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = net(images)
            val_loss += criterion(outputs, targets).item()

            total_val += targets.numel()
            predicted = torch.sigmoid(outputs) >= 0.5
            correct_val += (predicted == (targets == 1.0)).sum().item()
    return val_loss, correct_val, total_val

In [None]:
# Per-epoch history filled in by train_ingredient_model (plotted in Results).
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

def train_ingredient_model(model, train_loader, val_loader, criterion, optimizer, num_epochs,
                           scheduler=None, patience=7, restore_best_weights=False):
    """Train with per-epoch validation, LR scheduling and early stopping.

    Appends each epoch's average loss and per-label accuracy (in %) to the
    module-level ``train_losses``, ``train_accuracies``, ``val_losses`` and
    ``val_accuracies`` lists.

    Args:
        model: network being trained. ``train_loop``/``evaluate`` are called
            with their original signatures, so this should be the same object
            as the notebook-level ``model``.
        train_loader, val_loader: training / validation batch iterables.
        criterion: loss on logits.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        scheduler: object with ``step(metric)``, stepped with the average
            validation loss each epoch; defaults to the global ``lr_scheduler``.
        patience: epochs without validation-loss improvement before stopping.
        restore_best_weights: if True, reload the weights from the epoch with
            the lowest validation loss before returning.
    """
    if scheduler is None:
        scheduler = lr_scheduler
    best_val_loss = float('inf')
    best_state = None
    no_improvement = 0
    for epoch in range(num_epochs):
        start_time = time.time()
        # evaluate() leaves the model in eval mode; switch back for training.
        model.train()

        train_loss, correct_train, total_train = train_loop(
            train_loader, criterion, optimizer, 0.0, 0, 0)
        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = (correct_train / total_train) * 100.0
        train_accuracies.append(train_accuracy)
        train_losses.append(avg_train_loss)

        val_loss, correct_val, total_val = evaluate(
            val_loader, criterion, optimizer, 0.0, 0, 0)
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = (correct_val / total_val) * 100.0
        val_accuracies.append(val_accuracy)
        val_losses.append(avg_val_loss)
        scheduler.step(avg_val_loss)

        epoch_time = time.time() - start_time

        print(f"Epoch [{epoch+1}/{num_epochs}] - Time: {epoch_time:.2f} seconds\n"
            f"Train Accuracy: {train_accuracy:.2f}% - Train Loss: {avg_train_loss:.4f}\n"
            f"Validation Accuracy: {val_accuracy:.2f}% - Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improvement = 0
            if restore_best_weights:
                # Snapshot on CPU so the copy does not occupy accelerator memory.
                best_state = {name: tensor.detach().cpu().clone()
                              for name, tensor in model.state_dict().items()}
        else:
            no_improvement += 1

        if no_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if restore_best_weights and best_state is not None:
        model.load_state_dict(best_state)

In [None]:
# Fine-tuning hyperparameters.
NUM_EPOCHS = 60
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# BCEWithLogitsLoss applies a sigmoid per output, i.e. one independent binary
# decision per allergen.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Divide the learning rate by 10 after 5 epochs without val-loss improvement.
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
train_ingredient_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)

# Results

In [None]:
# Learning curves recorded by train_ingredient_model (one point per epoch).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

ax1.plot(train_losses, label='Training')
ax1.plot(val_losses, label='Validation')
ax1.set(title='Loss', xlabel='Epoch', ylabel='Average batch loss')
ax1.legend()

ax2.plot(train_accuracies, label='Training')
ax2.plot(val_accuracies, label='Validation')
ax2.set(title='Accuracy', xlabel='Epoch', ylabel='Per-label accuracy (%)', ylim=(0, 100))
ax2.legend()
plt.show()

In [None]:
# Visual sanity check: predicted vs. actual allergens for one test batch.
model.eval()
inputs, targets = next(iter(test_loader))
inputs, targets = inputs.float().to(mps_device), targets.to(mps_device)

# Make predictions (inference only, so no autograd graph is needed).
with torch.no_grad():
    outputs = model(inputs)
probabilities = torch.sigmoid(outputs).cpu()
targets_cpu = targets.cpu()

fig, (m_axs) = plt.subplots(10, 3, figsize=(30, 45), gridspec_kw={'hspace': 1.2, 'wspace': 1.2})
for i, c_ax in enumerate(m_axs.flatten()):
    if i >= inputs.size(0):
        # The batch has fewer images than grid panels: leave the rest empty.
        c_ax.axis('off')
        continue

    # Labels whose predicted probability exceeds 0.5, with that probability.
    # Items carry no embedded newline: textwrap.wrap would turn it into a
    # stray space before each comma.
    pred_title = ', '.join('{} ({:2.1f}%)'.format(allergens[j], 100 * p.item())
                           for j, p in enumerate(probabilities[i])
                           if p > 0.5)
    wrapped_pred = '\n'.join(textwrap.wrap(pred_title, width=50))

    act_title = ', '.join('{}'.format(allergens[j]) for j, v in enumerate(targets_cpu[i]) if v == 1.0)
    wrapped_act = '\n'.join(textwrap.wrap(act_title, width=50))

    c_ax.imshow(inputs[i].permute(1, 2, 0).cpu().numpy())
    c_ax.set_title("Predicted: {}\nActual: {}".format(wrapped_pred, wrapped_act))
plt.show()

In [None]:
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    """Example-based multi-label accuracy (often called the "Hamming score").

    For each sample the score is |T ∩ P| / |T ∪ P|, where T and P are the sets
    of label positions that are non-zero in ``y_true`` and ``y_pred``. A sample
    with no true and no predicted labels scores 1.

    Args:
        y_true, y_pred: array-likes of shape (n_samples, n_labels); any
            non-zero entry counts as a positive label.
        normalize: if True, return the (weighted) mean over samples;
            otherwise the (weighted) sum. Mirrors sklearn's ``accuracy_score``.
        sample_weight: optional per-sample weights.

    Returns:
        float score; NaN when there are no samples and ``normalize`` is True.
    """
    true_mask = np.asarray(y_true) != 0
    pred_mask = np.asarray(y_pred) != 0
    intersection = np.logical_and(true_mask, pred_mask).sum(axis=1)
    union = np.logical_or(true_mask, pred_mask).sum(axis=1)
    # Samples with an empty union count as a perfect match; np.maximum only
    # avoids a 0/0 warning in the branch np.where discards.
    per_sample = np.where(union == 0, 1.0, intersection / np.maximum(union, 1))

    if normalize:
        if per_sample.size == 0:
            return float('nan')
        return float(np.average(per_sample, weights=sample_weight))
    if sample_weight is not None:
        return float(np.dot(per_sample, sample_weight))
    return float(per_sample.sum())

In [None]:
# Evaluate on the held-out test set.
# The model outputs raw logits (it is trained with BCEWithLogitsLoss), so
# predictions are thresholded on sigmoid probabilities, not on the logits.
DECISION_THRESHOLD = 0.6

# torch.no_grad() already disables gradient tracking; parameters are not
# frozen, so any later training still updates them.
true_batches = []
pred_batches = []
model.eval()
with torch.no_grad():
    for images, labels in tqdm.tqdm(test_loader, desc='Testing'):
        images = images.to(mps_device)
        probabilities = torch.sigmoid(model(images))
        pred_batches.append((probabilities > DECISION_THRESHOLD).cpu().numpy().astype(int))
        true_batches.append(labels.cpu().numpy())
true_labels = np.concatenate(true_batches)
predicted_labels = np.concatenate(pred_batches)

report = classification_report(true_labels, predicted_labels, target_names=list(allergens), zero_division=0)
print(report)

f1score_samples = f1_score(y_true=true_labels, y_pred=predicted_labels, average='samples', zero_division=0)
f1score_macro = f1_score(y_true=true_labels, y_pred=predicted_labels, average='macro', zero_division=0)
f1score_weighted = f1_score(y_true=true_labels, y_pred=predicted_labels, average='weighted', zero_division=0)
recall = recall_score(y_true=true_labels, y_pred=predicted_labels, average='samples', zero_division=0)
prec = precision_score(y_true=true_labels, y_pred=predicted_labels, average='samples', zero_division=0)
hamming = hamming_score(y_true=true_labels, y_pred=predicted_labels)

# On multi-label input accuracy_score is subset accuracy (all labels must match).
accuracy = accuracy_score(true_labels, predicted_labels)
hl = hamming_loss(true_labels, predicted_labels)
print("Accuracy: ", accuracy)
print("F1 Samples: ", f1score_samples)
print("F1 Macro: ", f1score_macro)
print("F1 Weighted: ", f1score_weighted)
print("Precision Samples: ", prec)
print("Recall Samples: ", recall)
print("Hamming score: ", hamming)
print("Hamming loss: ", hl)