In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn


from sklearn.utils import class_weight

from torch.utils.data import DataLoader
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torch.optim import Adam

from tqdm.notebook import tqdm

from cnn_utils import (
    SealDataset,
    display_result_metrics, 
    generate_predictions_pytorch,
    get_labels_and_sub_images, 
    get_labels, 
)

In [None]:
# Choose the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using: ", device)

# Release any cached CUDA allocations left over from earlier work in this kernel.
torch.cuda.empty_cache()

In [None]:
# Build EfficientNetV2-M initialised from torchvision's default pre-trained weights.
pretrained_weights = EfficientNet_V2_M_Weights.DEFAULT
efficientnet = efficientnet_v2_m(weights=pretrained_weights)

# Swap the last classifier layer for a 2-output head (the two classes used in
# this notebook). NOTE(review): 1280 is presumably the width of the features
# entering the stock head -- confirm if the backbone is ever changed.
efficientnet.classifier[1] = nn.Linear(1280, 2)

# Move all parameters and buffers to the selected device.
efficientnet = efficientnet.to(device)

In [None]:
# Root directory holding the pre-generated .npy arrays.
data_path = "../Data"

# CAUTION: allow_pickle=True lets np.load unpickle Python objects, which can run
# arbitrary code. Only point this at files you trust.

# Training split: images and their bounding-box data.
train_img_data, train_bb_data = (
    np.load(f"{data_path}/train_images.npy", allow_pickle=True),
    np.load(f"{data_path}/train_bb_data.npy", allow_pickle=True),
)

# Validation split: images and their bounding-box data.
val_img_data, val_bb_data = (
    np.load(f"{data_path}/val_images.npy", allow_pickle=True),
    np.load(f"{data_path}/val_bb_data.npy", allow_pickle=True),
)

In [None]:
# Threshold handed to the cnn_utils labelling helpers for both splits.
# NOTE(review): exact meaning of the threshold lives in cnn_utils -- confirm there.
seal_threshold = 0.3

# Split the training sub-images (and their labels) into seal / no-seal groups.
(
    label_1_img,
    label_1,
    label_0_img,
    label_0,
) = get_labels_and_sub_images(
    train_img_data,
    train_bb_data,
    threshold=seal_threshold,
)

# Validation data only needs labels, derived with the same threshold.
val_label = get_labels(val_bb_data, seal_threshold)

In [None]:
# Augment the seal class with two flipped copies of every seal sub-image.
# NOTE: the final `+=` lines extend label_1_img / label_1 in place, so running
# this cell a second time augments the already-augmented data again.
transfomed_images = []

for seal_img in label_1_img:
    # np.fliplr reverses left/right (horizontal mirror); np.flipud reverses
    # up/down (vertical flip). Order per image: mirrored copy, then flipped copy.
    transfomed_images.extend((np.fliplr(seal_img), np.flipud(seal_img)))

# Every augmented copy is still a seal, so each one gets label 1.
transformed_labels = [1] * len(transfomed_images)

# Append the augmented samples to the seal images and labels.
label_1_img += transfomed_images
label_1 += transformed_labels

In [None]:
# Stack the training set as arrays: seal samples first, then no-seal samples.
# Images and labels are concatenated in the same order so that they stay aligned.
total_images = np.array(label_1_img + label_0_img)
total_labels = np.array(label_1 + label_0)

In [None]:
# Labels as integer (long) tensors for both splits.
total_labels_tensors = torch.tensor(total_labels, dtype=torch.long)
val_label_tensors = torch.tensor(val_label, dtype=torch.long)

# Training split: dataset wrapper plus a shuffled loader with batches of 10.
train_data = SealDataset(total_images, total_labels_tensors)
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)

# Validation split: dataset wrapper plus a shuffled loader with batches of 50.
valid_data = SealDataset(val_img_data, val_label_tensors)
valid_loader = DataLoader(valid_data, batch_size=50, shuffle=True)

In [None]:
# Per-class weights from scikit-learn's "balanced" scheme, computed over the
# (augmented) training labels, then moved to the training device as float32.
unique_classes = np.unique(total_labels)
weights = class_weight.compute_class_weight("balanced", classes=unique_classes, y=total_labels)
class_weights = torch.FloatTensor(weights).to(device)

In [None]:
# Training objective: cross-entropy, weighted per class by `class_weights`.
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# Adam over every parameter of the network, learning rate 3e-4.
optimizer = Adam(params=efficientnet.parameters(), lr=3e-4)

In [None]:
def train_model(model, epochs:int, opt, loss, dataloader:DataLoader, device:str) -> None:
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    A tqdm progress bar shows the running accuracy and mean loss while an epoch
    is in progress, and a one-line summary is printed when the epoch ends.

    Args:
        model: Network called as ``model(batch)``. Its output is indexed as
            ``(sample, class)``, i.e. one row of raw class scores per sample.
        epochs: Number of full passes over ``dataloader``.
        opt: Optimizer that updates ``model``'s parameters.
        loss: Callable ``loss(scores, labels)`` returning a scalar tensor.
        dataloader: Yields ``(data, label)`` batches. Labels are cast to float
            before being passed to ``loss``. For the accuracy display, 2-D
            labels are read as one row per sample (one-hot / per-class) and
            1-D labels as class indices.
        device: Device the batches are moved to. Anything accepted by
            ``Tensor.to`` works (a device name or a ``torch.device``).

    Returns:
        None. The model is updated in place and is left in training mode.
    """
    model.train()

    for epoch_num in range(epochs):
        # Running totals for the progress display
        total_loss = 0.0
        correct = 0
        total_seen = 0
        loop = tqdm(dataloader)

        # Train each batch
        for data, label in loop:
            # Put data and label on device. `.float()` casts on whatever device
            # the tensor already lives on (no detour through CPU memory).
            data = data.to(device)
            label = label.to(device).float()

            # Forward propagate, using the loss callable that was passed in
            opt.zero_grad()
            yhat = model(data).float()
            batch_loss = loss(yhat, label)
            batch_size = yhat.shape[0]

            # Calculate metrics for batch (for display only)
            with torch.no_grad():
                # The model emits raw scores, not probabilities, so the
                # predicted class is the arg-max rather than "score > 0.5".
                predicted = yhat.argmax(dim=1)
                actual = label.argmax(dim=1) if label.dim() > 1 else label.long()
                correct += int((predicted == actual).sum().item())

            # Weight by batch size so that total_loss / total_seen is a
            # per-sample average even when the final batch is smaller.
            # (Assumes `loss` returns a batch mean, as nn.CrossEntropyLoss does
            # by default.)
            total_loss += batch_loss.item() * batch_size
            total_seen += batch_size

            # Backward propagate
            batch_loss.backward()
            opt.step()

            # Display batch metrics
            loop.set_description("Epoch: {}      Accuracy: {}      Loss: {}      ".format(epoch_num + 1, round(correct/total_seen, 4), round(total_loss/total_seen, 4)))
            loop.refresh()

        # Display epoch metrics. Divide by the samples actually processed and
        # guard against an empty loader.
        seen = max(total_seen, 1)
        print(f"Epoch:{epoch_num + 1} Loss:{total_loss / seen} Accuracy: {correct / seen}")

In [None]:
# Fine-tune the network for a single epoch on the augmented training set.
train_model(
    model=efficientnet,
    epochs=1,
    opt=optimizer,
    loss=loss_fn,
    dataloader=train_loader,
    device=device,
)

In [None]:
# Evaluate model
# train_model() puts the network in training mode and never switches it back,
# so select evaluation mode before predicting on the validation set. Layers
# such as dropout and batch normalisation behave differently in training mode.
# NOTE(review): generate_predictions_pytorch may already do this; calling
# eval() a second time is harmless.
efficientnet.eval()

actual_labels, predicted_labels = generate_predictions_pytorch(efficientnet, valid_loader, device)

display_result_metrics(actual_labels, predicted_labels)

In [None]:
# Persist only the learned parameters (the state dict), not the whole module.
model_name = "cnn_efficient_net"

torch.save(
    obj=efficientnet.state_dict(),
    f=f"../Models/PyTorch/{model_name}",
)