In [1]:
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torchvision.models as models
from tqdm import tqdm
from torchvision import transforms
from sklearn.metrics import roc_auc_score

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [34]:
# Load the data
df = pd.read_csv('/content/drive/My Drive/data/train data/metadata.csv')
df.head()

Unnamed: 0,date,id_coord,plume,set,lat,lon,coord_x,coord_y,path
0,20230223,id_6675,yes,train,31.52875,74.330625,24,47,images/plume/20230223_methane_mixing_ratio_id_...
1,20230103,id_2542,yes,train,35.538,112.524,42,37,images/plume/20230103_methane_mixing_ratio_id_...
2,20230301,id_6546,yes,train,21.06,84.936667,58,15,images/plume/20230301_methane_mixing_ratio_id_...
3,20230225,id_6084,yes,train,26.756667,80.973333,28,62,images/plume/20230225_methane_mixing_ratio_id_...
4,20230105,id_2012,yes,train,34.8,40.77,59,44,images/plume/20230105_methane_mixing_ratio_id_...


In [35]:
# Split the data
train_df, valid_df = train_test_split(df, train_size =0.8, test_size=0.2, random_state=42)

In [36]:
# extract file names from path
train_df['new_path'] = train_df['path'].str.replace(r'images/(plume|no_plume)/', '', regex=True) + ".tif"
valid_df['new_path'] = valid_df['path'].str.replace(r'images/(plume|no_plume)/', '', regex=True) + ".tif"

In [37]:
#For each train and validation datasets seggregate data into separate folders based on plume or no plume
import os
import random
import shutil

# Set up paths
data_dir = '/content/drive/My Drive/data/train data/images'
train_dir= 'train'
val_dir = 'validation'

plume_dir = 'plume'
no_plume_dir = 'no_plume'

# Create validation directories
os.makedirs(os.path.join(train_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(train_dir, no_plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, no_plume_dir), exist_ok=True)

for index, row in train_df.iterrows():
    image_file = row['new_path']
    is_plume = row['plume']

    if is_plume == 'yes':
      if not os.path.exists(os.path.join(train_dir, plume_dir, image_file)):
        shutil.move(os.path.join(data_dir, plume_dir, image_file),
                os.path.join(train_dir, plume_dir))
    else:
      if not os.path.exists(os.path.join(train_dir, no_plume_dir, image_file)):
        shutil.move(os.path.join(data_dir, no_plume_dir, image_file),
                os.path.join(train_dir, no_plume_dir))

for index, row in valid_df.iterrows():
    image_file = row['new_path']
    is_plume = row['plume']

    if is_plume == 'yes':
      if not os.path.exists(os.path.join(val_dir, plume_dir, image_file)):
        shutil.move(os.path.join(data_dir, plume_dir, image_file),
                os.path.join(val_dir, plume_dir))
    else:
      if not os.path.exists(os.path.join(val_dir, no_plume_dir, image_file)):
        shutil.move(os.path.join(data_dir, no_plume_dir, image_file),
                os.path.join(val_dir, no_plume_dir))


In [38]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, augment=False):
        self.root_dir = root_dir
        self.augment = augment
        self.image_list = []
        self.labels = []
        self.augmented_images = []

        plume_images = []
        no_plume_images = []

        classes = os.listdir(root_dir)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                image_names = os.listdir(class_dir)
                tiff_image_names = [img_name for img_name in image_names if img_name.lower().endswith('.tiff') or img_name.lower().endswith('.tif')]

                if class_name == 'plume':
                    plume_images.extend([os.path.join(class_dir, img_name) for img_name in tiff_image_names])
                else:
                    no_plume_images.extend([os.path.join(class_dir, img_name) for img_name in tiff_image_names])

                self.labels.extend([1 if class_name == 'plume' else 0] * len(tiff_image_names))

        # Calculate the number of images in each class
        num_images = min(len(plume_images), len(no_plume_images))

        if augment:
            augmented_images = []
            for _ in range(num_images):
                # Randomly select plume and no plume images
                plume_image_path = random.choice(plume_images)
                no_plume_image_path = random.choice(no_plume_images)

                # Load and augment plume image
                plume_image = Image.open(plume_image_path)
                plume_image = plume_image.resize((64, 64))  # Resize to 64x64
                transform = transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(10),
                    transforms.ToTensor()
                ])
                augmented_plume_image = transform(plume_image)
                augmented_images.append(augmented_plume_image)

                # Load and augment no plume image
                no_plume_image = Image.open(no_plume_image_path)
                no_plume_image = no_plume_image.resize((64, 64))  # Resize to 64x64
                augmented_no_plume_image = transform(no_plume_image)
                augmented_images.append(augmented_no_plume_image)

            self.augmented_images = augmented_images

        self.image_list = plume_images[:num_images] + no_plume_images[:num_images]
        self.labels = [1] * num_images + [0] * num_images

        if augment:
            num_augmented_images = len(self.augmented_images)
            self.image_list.extend(self.augmented_images)
            self.labels.extend(self.labels[:num_augmented_images])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = self.image_list[idx]
        label = self.labels[idx]

        if isinstance(image, str):
            image = Image.open(image)
            image = image.resize((64, 64))  # Resize to 64x64

            transform = transforms.Compose([
                transforms.ToTensor()
            ])
            image = transform(image)

        return image, label

In [39]:
root_dir = "train"
val_dir = "validation"

train_dataset = CustomDataset(root_dir, augment=True)
val_dataset = CustomDataset(val_dir, augment=False)


# Define batch size for training and validation
batch_size = 32

# Create the train DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create the validation DataLoader
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Verify the train and validation data loaders
for images, labels in train_loader:
    print(f"Train batch - Images: {images.shape}, Labels: {labels.shape}")
    break

for images, labels in val_loader:
    print(f"Validation batch - Images: {images.shape}, Labels: {labels.shape}")
    break

Train batch - Images: torch.Size([32, 1, 64, 64]), Labels: torch.Size([32])
Validation batch - Images: torch.Size([32, 1, 64, 64]), Labels: torch.Size([32])


In [40]:
train_dataset

<__main__.CustomDataset at 0x7c42792c39a0>

In [13]:
#We can check that our train dataset is balanced between plume and no plume
num_plumes = sum(label == 1 for label in train_dataset.labels)
num_no_plumes = sum(label == 0 for label in train_dataset.labels)

print("Number of training samples:", len(train_dataset))
print(f"Number of plumes: {num_plumes}")
print(f"Number of no plumes: {num_no_plumes}")

Number of training samples: 568
Number of plumes: 284
Number of no plumes: 284


#WideResNet50

In [10]:
# Define the device for training (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet50 model
resnet50 = models.wide_resnet50_2(weights='Wide_ResNet50_2_Weights.IMAGENET1K_V2')

# Modify the first layer to accept single-channel grayscale images
resnet50.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the last fully connected layer for binary classification with softmax activation
num_classes = 2  # 2 classes: 1 or 0
resnet50.fc = nn.Sequential(
    nn.Linear(resnet50.fc.in_features, num_classes)
)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-9ba9bcbe.pth
100%|██████████| 263M/263M [00:23<00:00, 11.7MB/s]


In [14]:
# Move the model to the device
resnet50 = resnet50.to(device)

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(resnet50.parameters(), lr=0.001)

num_epochs = 50

In [15]:
best_auc = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet50.train()  # Set the model to training mode

    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet50(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        epoch_loss += loss.item()

    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    # Validation loop
    resnet50.eval()  # Set the model to evaluation mode

    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet50(images)
            probabilities = torch.softmax(outputs, dim=1)

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")

    # Check if current accuracy is better than the previous best accuracy
    if auc > best_auc:
        best_auc = auc
        torch.save(resnet50.state_dict(), f"resnet50_{best_auc:.4f}.pth")  # Save the model with AUC in the filename
        print("Saved")

Epochs:   0%|          | 0/50 [00:02<?, ?it/s]

Epoch 1/50 - Average Loss: 0.2112888248844279
Validation AUC: 0.7691326530612245


Epochs:   2%|▏         | 1/50 [00:02<02:15,  2.77s/it]

Saved


Epochs:   4%|▍         | 2/50 [00:04<01:51,  2.33s/it]

Epoch 2/50 - Average Loss: 0.16391508684804043
Validation AUC: 0.7461734693877552


Epochs:   4%|▍         | 2/50 [00:06<01:51,  2.33s/it]

Epoch 3/50 - Average Loss: 0.052294033747683794
Validation AUC: 0.8134566326530612


Epochs:   6%|▌         | 3/50 [00:07<02:03,  2.63s/it]

Saved


Epochs:   8%|▊         | 4/50 [00:09<01:51,  2.42s/it]

Epoch 4/50 - Average Loss: 0.05298114856446369
Validation AUC: 0.7699298469387755


Epochs:  10%|█         | 5/50 [00:11<01:42,  2.28s/it]

Epoch 5/50 - Average Loss: 0.03824260899434901
Validation AUC: 0.7517538265306123


Epochs:  12%|█▏        | 6/50 [00:13<01:36,  2.20s/it]

Epoch 6/50 - Average Loss: 0.05090522672010896
Validation AUC: 0.8102678571428571


Epochs:  14%|█▍        | 7/50 [00:15<01:32,  2.14s/it]

Epoch 7/50 - Average Loss: 0.16462485141689992
Validation AUC: 0.7189094387755102


Epochs:  16%|█▌        | 8/50 [00:18<01:28,  2.11s/it]

Epoch 8/50 - Average Loss: 0.18545603131254515
Validation AUC: 0.6954719387755102


Epochs:  18%|█▊        | 9/50 [00:20<01:26,  2.12s/it]

Epoch 9/50 - Average Loss: 0.15568828541371557
Validation AUC: 0.7566964285714286


Epochs:  20%|██        | 10/50 [00:22<01:24,  2.12s/it]

Epoch 10/50 - Average Loss: 0.13939469028264284
Validation AUC: 0.6455676020408163


Epochs:  22%|██▏       | 11/50 [00:24<01:22,  2.10s/it]

Epoch 11/50 - Average Loss: 0.11549520896126826
Validation AUC: 0.7761479591836734


Epochs:  24%|██▍       | 12/50 [00:26<01:19,  2.09s/it]

Epoch 12/50 - Average Loss: 0.05553787750088506
Validation AUC: 0.7917729591836735


Epochs:  26%|██▌       | 13/50 [00:28<01:17,  2.08s/it]

Epoch 13/50 - Average Loss: 0.030141308237539813
Validation AUC: 0.7404336734693878


Epochs:  28%|██▊       | 14/50 [00:30<01:14,  2.08s/it]

Epoch 14/50 - Average Loss: 0.03154011699047664
Validation AUC: 0.7578125


Epochs:  30%|███       | 15/50 [00:32<01:13,  2.09s/it]

Epoch 15/50 - Average Loss: 0.06033480777599228
Validation AUC: 0.7241709183673469


Epochs:  32%|███▏      | 16/50 [00:34<01:11,  2.11s/it]

Epoch 16/50 - Average Loss: 0.0715346976260965
Validation AUC: 0.7551020408163266


Epochs:  34%|███▍      | 17/50 [00:36<01:09,  2.11s/it]

Epoch 17/50 - Average Loss: 0.1339857484239878
Validation AUC: 0.7653061224489796


Epochs:  36%|███▌      | 18/50 [00:38<01:06,  2.09s/it]

Epoch 18/50 - Average Loss: 0.24724194552335474
Validation AUC: 0.7346938775510204


Epochs:  36%|███▌      | 18/50 [00:40<01:06,  2.09s/it]

Epoch 19/50 - Average Loss: 0.10314832317332427
Validation AUC: 0.8163265306122449


Epochs:  38%|███▊      | 19/50 [00:41<01:09,  2.25s/it]

Saved


Epochs:  40%|████      | 20/50 [00:43<01:05,  2.19s/it]

Epoch 20/50 - Average Loss: 0.054339351164849684
Validation AUC: 0.8061224489795918


Epochs:  40%|████      | 20/50 [00:45<01:05,  2.19s/it]

Epoch 21/50 - Average Loss: 0.05930237675784156
Validation AUC: 0.8485331632653061


Epochs:  42%|████▏     | 21/50 [00:46<01:09,  2.38s/it]

Saved


Epochs:  44%|████▍     | 22/50 [00:49<01:12,  2.57s/it]

Epoch 22/50 - Average Loss: 0.04488101098427756
Validation AUC: 0.8191964285714286


Epochs:  46%|████▌     | 23/50 [00:51<01:05,  2.42s/it]

Epoch 23/50 - Average Loss: 0.029096867599744454
Validation AUC: 0.7863520408163266


Epochs:  48%|████▊     | 24/50 [00:53<00:59,  2.30s/it]

Epoch 24/50 - Average Loss: 0.03703818429817653
Validation AUC: 0.7911352040816326


Epochs:  50%|█████     | 25/50 [00:55<00:55,  2.22s/it]

Epoch 25/50 - Average Loss: 0.02439997962857079
Validation AUC: 0.7860331632653061


Epochs:  52%|█████▏    | 26/50 [00:57<00:52,  2.17s/it]

Epoch 26/50 - Average Loss: 0.018490212720482506
Validation AUC: 0.7802933673469387


Epochs:  54%|█████▍    | 27/50 [00:59<00:48,  2.13s/it]

Epoch 27/50 - Average Loss: 0.02005995724013903
Validation AUC: 0.7662627551020409


Epochs:  56%|█████▌    | 28/50 [01:01<00:46,  2.12s/it]

Epoch 28/50 - Average Loss: 0.016238772794571962
Validation AUC: 0.7822066326530612


Epochs:  58%|█████▊    | 29/50 [01:03<00:44,  2.11s/it]

Epoch 29/50 - Average Loss: 0.020635346976632718
Validation AUC: 0.7847576530612244


Epochs:  60%|██████    | 30/50 [01:05<00:41,  2.09s/it]

Epoch 30/50 - Average Loss: 0.017548926211828884
Validation AUC: 0.7777423469387755


Epochs:  62%|██████▏   | 31/50 [01:07<00:39,  2.07s/it]

Epoch 31/50 - Average Loss: 0.01612096231787695
Validation AUC: 0.7761479591836734


Epochs:  64%|██████▍   | 32/50 [01:09<00:37,  2.06s/it]

Epoch 32/50 - Average Loss: 0.017033371887616038
Validation AUC: 0.779655612244898


Epochs:  66%|██████▌   | 33/50 [01:12<00:34,  2.05s/it]

Epoch 33/50 - Average Loss: 0.017418550696068753
Validation AUC: 0.7691326530612245


Epochs:  68%|██████▊   | 34/50 [01:14<00:32,  2.06s/it]

Epoch 34/50 - Average Loss: 0.016598824533098702
Validation AUC: 0.7697704081632654


Epochs:  70%|███████   | 35/50 [01:16<00:31,  2.08s/it]

Epoch 35/50 - Average Loss: 0.01661910585648406
Validation AUC: 0.7904974489795918


Epochs:  72%|███████▏  | 36/50 [01:18<00:29,  2.08s/it]

Epoch 36/50 - Average Loss: 0.017149588477473623
Validation AUC: 0.7936862244897959


Epochs:  74%|███████▍  | 37/50 [01:20<00:26,  2.06s/it]

Epoch 37/50 - Average Loss: 0.016518796980158693
Validation AUC: 0.7876275510204083


Epochs:  76%|███████▌  | 38/50 [01:22<00:24,  2.06s/it]

Epoch 38/50 - Average Loss: 0.015926250788500813
Validation AUC: 0.786670918367347


Epochs:  78%|███████▊  | 39/50 [01:24<00:22,  2.05s/it]

Epoch 39/50 - Average Loss: 0.016584965917192877
Validation AUC: 0.7850765306122448


Epochs:  80%|████████  | 40/50 [01:26<00:20,  2.05s/it]

Epoch 40/50 - Average Loss: 0.016277100799874863
Validation AUC: 0.7879464285714285


Epochs:  82%|████████▏ | 41/50 [01:28<00:18,  2.07s/it]

Epoch 41/50 - Average Loss: 0.016977840915792006
Validation AUC: 0.7901785714285714


Epochs:  84%|████████▍ | 42/50 [01:30<00:16,  2.08s/it]

Epoch 42/50 - Average Loss: 0.016321240548980615
Validation AUC: 0.7841198979591836


Epochs:  86%|████████▌ | 43/50 [01:32<00:14,  2.06s/it]

Epoch 43/50 - Average Loss: 0.01570254700651377
Validation AUC: 0.7796556122448979


Epochs:  88%|████████▊ | 44/50 [01:34<00:12,  2.05s/it]

Epoch 44/50 - Average Loss: 0.015433444151666562
Validation AUC: 0.7790178571428571


Epochs:  90%|█████████ | 45/50 [01:36<00:10,  2.05s/it]

Epoch 45/50 - Average Loss: 0.015508680593736952
Validation AUC: 0.7815688775510204


Epochs:  92%|█████████▏| 46/50 [01:38<00:08,  2.04s/it]

Epoch 46/50 - Average Loss: 0.015575528000833097
Validation AUC: 0.782844387755102


Epochs:  94%|█████████▍| 47/50 [01:40<00:06,  2.06s/it]

Epoch 47/50 - Average Loss: 0.015778762302817186
Validation AUC: 0.7825255102040816


Epochs:  96%|█████████▌| 48/50 [01:43<00:04,  2.08s/it]

Epoch 48/50 - Average Loss: 0.01630731696801983
Validation AUC: 0.7834821428571429


Epochs:  98%|█████████▊| 49/50 [01:45<00:02,  2.07s/it]

Epoch 49/50 - Average Loss: 0.015802534648527298
Validation AUC: 0.78125


Epochs: 100%|██████████| 50/50 [01:47<00:00,  2.14s/it]

Epoch 50/50 - Average Loss: 0.016014952371885254
Validation AUC: 0.7838010204081634





###SGD with momentum and LR scheduler - ❌

In [20]:
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
optimizer = SGD(resnet50.parameters(), lr=0.01, momentum=0.9)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

In [21]:
num_epochs = 40
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_auc = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet50.train()  # Set the model to training mode

    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet50(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        epoch_loss += loss.item()

    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    scheduler.step()

    # Validation loop
    resnet50.eval()  # Set the model to evaluation mode

    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet50(images)
            probabilities = torch.softmax(outputs, dim=1)

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")

    # Check if current accuracy is better than the previous best accuracy
    if auc > best_auc:
        best_auc = auc
        torch.save(resnet50.state_dict(), f"resnet50_{best_auc:.4f}.pth")  # Save the model with AUC in the filename
        print("Saved")

Epochs:   0%|          | 0/40 [00:01<?, ?it/s]

Epoch 1/40 - Average Loss: 0.015233904148582647
Validation AUC: 0.7815688775510204


Epochs:   2%|▎         | 1/40 [00:02<01:40,  2.57s/it]

Saved


Epochs:   5%|▌         | 2/40 [00:04<01:21,  2.15s/it]

Epoch 2/40 - Average Loss: 0.015856379969288053
Validation AUC: 0.7799744897959183


Epochs:   8%|▊         | 3/40 [00:06<01:14,  2.01s/it]

Epoch 3/40 - Average Loss: 0.015577323773666447
Validation AUC: 0.7777423469387754


Epochs:   8%|▊         | 3/40 [00:07<01:14,  2.01s/it]

Epoch 4/40 - Average Loss: 0.015986206741824088
Validation AUC: 0.7841198979591837


Epochs:  10%|█         | 4/40 [00:08<01:18,  2.18s/it]

Saved


Epochs:  12%|█▎        | 5/40 [00:10<01:12,  2.07s/it]

Epoch 5/40 - Average Loss: 0.015615227152920852
Validation AUC: 0.7791772959183674


Epochs:  15%|█▌        | 6/40 [00:12<01:09,  2.03s/it]

Epoch 6/40 - Average Loss: 0.015498931793394554
Validation AUC: 0.7777423469387755


Epochs:  18%|█▊        | 7/40 [00:14<01:06,  2.01s/it]

Epoch 7/40 - Average Loss: 0.015290515153386272
Validation AUC: 0.7806122448979592


Epochs:  20%|██        | 8/40 [00:16<01:02,  1.96s/it]

Epoch 8/40 - Average Loss: 0.016058539702044072
Validation AUC: 0.7815688775510203


Epochs:  22%|██▎       | 9/40 [00:18<00:59,  1.93s/it]

Epoch 9/40 - Average Loss: 0.015404441515430436
Validation AUC: 0.7802933673469388


Epochs:  25%|██▌       | 10/40 [00:20<00:57,  1.91s/it]

Epoch 10/40 - Average Loss: 0.015450016248425728
Validation AUC: 0.7751913265306122


Epochs:  28%|██▊       | 11/40 [00:21<00:54,  1.89s/it]

Epoch 11/40 - Average Loss: 0.015400607653444039
Validation AUC: 0.7780612244897959


Epochs:  30%|███       | 12/40 [00:23<00:52,  1.88s/it]

Epoch 12/40 - Average Loss: 0.0164850149047753
Validation AUC: 0.7806122448979592


Epochs:  32%|███▎      | 13/40 [00:25<00:51,  1.90s/it]

Epoch 13/40 - Average Loss: 0.015472065638884184
Validation AUC: 0.7809311224489797


Epochs:  35%|███▌      | 14/40 [00:27<00:49,  1.91s/it]

Epoch 14/40 - Average Loss: 0.015821158120186536
Validation AUC: 0.7802933673469388


Epochs:  38%|███▊      | 15/40 [00:29<00:47,  1.89s/it]

Epoch 15/40 - Average Loss: 0.015538067717879281
Validation AUC: 0.7818877551020408


Epochs:  38%|███▊      | 15/40 [00:31<00:47,  1.89s/it]

Epoch 16/40 - Average Loss: 0.015401908254539699
Validation AUC: 0.7844387755102041


Epochs:  40%|████      | 16/40 [00:31<00:49,  2.06s/it]

Saved


Epochs:  42%|████▎     | 17/40 [00:33<00:45,  1.99s/it]

Epoch 17/40 - Average Loss: 0.01621469552547852
Validation AUC: 0.7793367346938775


Epochs:  45%|████▌     | 18/40 [00:35<00:42,  1.94s/it]

Epoch 18/40 - Average Loss: 0.01581512971718742
Validation AUC: 0.7809311224489797


Epochs:  48%|████▊     | 19/40 [00:37<00:40,  1.91s/it]

Epoch 19/40 - Average Loss: 0.016689801867957814
Validation AUC: 0.7815688775510204


Epochs:  48%|████▊     | 19/40 [00:39<00:40,  1.91s/it]

Epoch 20/40 - Average Loss: 0.015473098338652057


Epochs:  50%|█████     | 20/40 [00:39<00:39,  1.98s/it]

Validation AUC: 0.779655612244898


Epochs:  50%|█████     | 20/40 [00:42<00:39,  1.98s/it]

Epoch 21/40 - Average Loss: 0.015472363619374442
Validation AUC: 0.7771045918367347

Epochs:  52%|█████▎    | 21/40 [00:42<00:42,  2.24s/it]




Epochs:  55%|█████▌    | 22/40 [00:44<00:38,  2.14s/it]

Epoch 22/40 - Average Loss: 0.015547700046959613
Validation AUC: 0.7801339285714285


Epochs:  57%|█████▊    | 23/40 [00:46<00:35,  2.08s/it]

Epoch 23/40 - Average Loss: 0.015364627675378668
Validation AUC: 0.7799744897959184


Epochs:  60%|██████    | 24/40 [00:48<00:32,  2.03s/it]

Epoch 24/40 - Average Loss: 0.015506440256053288
Validation AUC: 0.7815688775510203


Epochs:  62%|██████▎   | 25/40 [00:50<00:29,  1.99s/it]

Epoch 25/40 - Average Loss: 0.015454071558148522
Validation AUC: 0.7831632653061225


Epochs:  65%|██████▌   | 26/40 [00:51<00:27,  1.95s/it]

Epoch 26/40 - Average Loss: 0.015926951203785695
Validation AUC: 0.7785395408163265


Epochs:  68%|██████▊   | 27/40 [00:53<00:25,  1.93s/it]

Epoch 27/40 - Average Loss: 0.01526881693712312
Validation AUC: 0.776466836734694


Epochs:  70%|███████   | 28/40 [00:55<00:23,  1.92s/it]

Epoch 28/40 - Average Loss: 0.016650380775887343
Validation AUC: 0.78125


Epochs:  72%|███████▎  | 29/40 [00:57<00:20,  1.89s/it]

Epoch 29/40 - Average Loss: 0.015666740373477397
Validation AUC: 0.7825255102040816


Epochs:  75%|███████▌  | 30/40 [00:59<00:18,  1.86s/it]

Epoch 30/40 - Average Loss: 0.015298973167419818
Validation AUC: 0.7809311224489796


Epochs:  78%|███████▊  | 31/40 [01:01<00:16,  1.85s/it]

Epoch 31/40 - Average Loss: 0.01591103294915936
Validation AUC: 0.78125


Epochs:  80%|████████  | 32/40 [01:03<00:14,  1.84s/it]

Epoch 32/40 - Average Loss: 0.015296316964578536
Validation AUC: 0.7793367346938775


Epochs:  82%|████████▎ | 33/40 [01:04<00:12,  1.83s/it]

Epoch 33/40 - Average Loss: 0.01573871091627805
Validation AUC: 0.7802933673469388


Epochs:  85%|████████▌ | 34/40 [01:06<00:11,  1.84s/it]

Epoch 34/40 - Average Loss: 0.015698906542259767
Validation AUC: 0.7809311224489797


Epochs:  88%|████████▊ | 35/40 [01:08<00:09,  1.86s/it]

Epoch 35/40 - Average Loss: 0.015391446139536432
Validation AUC: 0.7809311224489796


Epochs:  90%|█████████ | 36/40 [01:10<00:07,  1.85s/it]

Epoch 36/40 - Average Loss: 0.015457555892377664
Validation AUC: 0.78125


Epochs:  92%|█████████▎| 37/40 [01:12<00:05,  1.84s/it]

Epoch 37/40 - Average Loss: 0.01571205532283986
Validation AUC: 0.7790178571428572


Epochs:  95%|█████████▌| 38/40 [01:14<00:03,  1.83s/it]

Epoch 38/40 - Average Loss: 0.015324722350543298
Validation AUC: 0.7790178571428572


Epochs:  98%|█████████▊| 39/40 [01:15<00:01,  1.82s/it]

Epoch 39/40 - Average Loss: 0.015949370959560838
Validation AUC: 0.7825255102040816


Epochs: 100%|██████████| 40/40 [01:17<00:00,  1.94s/it]

Epoch 40/40 - Average Loss: 0.015427154767924852
Validation AUC: 0.7825255102040816





###AdaGrad with L2 reg - ❌

In [27]:
import torch.optim as optim
optimizer = optim.Adagrad(resnet50.parameters(), lr=0.01)#, weight_decay=1e-4)

In [28]:
num_epochs = 40
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_auc = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet50.train()  # Set the model to training mode

    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet50(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        epoch_loss += loss.item()

    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    # Validation loop
    resnet50.eval()  # Set the model to evaluation mode

    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet50(images)
            probabilities = torch.softmax(outputs, dim=1)

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")

    # Check if current accuracy is better than the previous best accuracy
    if auc > best_auc:
        best_auc = auc
        torch.save(resnet50.state_dict(), f"resnet50_{best_auc:.4f}.pth")  # Save the model with AUC in the filename
        print("Saved")

Epochs:   0%|          | 0/40 [00:01<?, ?it/s]

Epoch 1/40 - Average Loss: 1.7121212784614828
Validation AUC: 0.5653698979591837


Epochs:   2%|▎         | 1/40 [00:02<01:39,  2.56s/it]

Saved


Epochs:   2%|▎         | 1/40 [00:04<01:39,  2.56s/it]

Epoch 2/40 - Average Loss: 1.1882405115498438
Validation AUC: 0.5959821428571429


Epochs:   5%|▌         | 2/40 [00:05<01:42,  2.69s/it]

Saved


Epochs:   5%|▌         | 2/40 [00:07<01:42,  2.69s/it]

Epoch 3/40 - Average Loss: 0.9303947256671058
Validation AUC: 0.7551020408163266


Epochs:   8%|▊         | 3/40 [00:08<01:53,  3.05s/it]

Saved


Epochs:  10%|█         | 4/40 [00:11<01:37,  2.71s/it]

Epoch 4/40 - Average Loss: 0.7436023718780942
Validation AUC: 0.642219387755102


Epochs:  12%|█▎        | 5/40 [00:13<01:25,  2.45s/it]

Epoch 5/40 - Average Loss: 0.729595939318339
Validation AUC: 0.6613520408163265


Epochs:  12%|█▎        | 5/40 [00:14<01:25,  2.45s/it]

Epoch 6/40 - Average Loss: 0.7083308829201592
Validation AUC: 0.7665816326530611


Epochs:  15%|█▌        | 6/40 [00:15<01:24,  2.49s/it]

Saved


Epochs:  18%|█▊        | 7/40 [00:17<01:16,  2.32s/it]

Epoch 7/40 - Average Loss: 0.7216365999645658
Validation AUC: 0.7071109693877551


Epochs:  20%|██        | 8/40 [00:19<01:10,  2.20s/it]

Epoch 8/40 - Average Loss: 0.5493289861414168
Validation AUC: 0.75


Epochs:  22%|██▎       | 9/40 [00:21<01:05,  2.13s/it]

Epoch 9/40 - Average Loss: 0.5085370391607285
Validation AUC: 0.7318239795918368


Epochs:  25%|██▌       | 10/40 [00:23<01:03,  2.11s/it]

Epoch 10/40 - Average Loss: 0.4233457131518258
Validation AUC: 0.7042410714285714


Epochs:  28%|██▊       | 11/40 [00:25<01:00,  2.08s/it]

Epoch 11/40 - Average Loss: 0.3777909990813997
Validation AUC: 0.7366071428571429


Epochs:  30%|███       | 12/40 [00:27<00:57,  2.04s/it]

Epoch 12/40 - Average Loss: 0.3435763484901852
Validation AUC: 0.6932397959183674


Epochs:  32%|███▎      | 13/40 [00:29<00:54,  2.01s/it]

Epoch 13/40 - Average Loss: 0.2651255552967389
Validation AUC: 0.6938775510204082


Epochs:  35%|███▌      | 14/40 [00:31<00:51,  1.99s/it]

Epoch 14/40 - Average Loss: 0.21329489226142564
Validation AUC: 0.7173150510204082


Epochs:  38%|███▊      | 15/40 [00:33<00:49,  1.98s/it]

Epoch 15/40 - Average Loss: 0.15350163645214504
Validation AUC: 0.7377232142857142


Epochs:  40%|████      | 16/40 [00:35<00:47,  1.98s/it]

Epoch 16/40 - Average Loss: 0.14267724917994606
Validation AUC: 0.696906887755102


Epochs:  42%|████▎     | 17/40 [00:37<00:46,  2.00s/it]

Epoch 17/40 - Average Loss: 0.1236606819762124
Validation AUC: 0.7021683673469388


Epochs:  45%|████▌     | 18/40 [00:39<00:44,  2.00s/it]

Epoch 18/40 - Average Loss: 0.1301672582825025
Validation AUC: 0.6989795918367347


Epochs:  48%|████▊     | 19/40 [00:41<00:41,  1.99s/it]

Epoch 19/40 - Average Loss: 0.10664572959972753
Validation AUC: 0.7147640306122449


Epochs:  50%|█████     | 20/40 [00:43<00:39,  1.97s/it]

Epoch 20/40 - Average Loss: 0.07818232476711273
Validation AUC: 0.6377551020408163


Epochs:  52%|█████▎    | 21/40 [00:45<00:37,  1.97s/it]

Epoch 21/40 - Average Loss: 0.10281835848258601
Validation AUC: 0.6919642857142857


Epochs:  55%|█████▌    | 22/40 [00:47<00:35,  1.97s/it]

Epoch 22/40 - Average Loss: 0.09118753195636803
Validation AUC: 0.6886160714285714


Epochs:  57%|█████▊    | 23/40 [00:49<00:33,  1.98s/it]

Epoch 23/40 - Average Loss: 0.0540239844057295
Validation AUC: 0.6887755102040817


Epochs:  60%|██████    | 24/40 [00:51<00:32,  2.00s/it]

Epoch 24/40 - Average Loss: 0.061380356239775814
Validation AUC: 0.7168367346938775


Epochs:  62%|██████▎   | 25/40 [00:53<00:29,  2.00s/it]

Epoch 25/40 - Average Loss: 0.08950176591881448
Validation AUC: 0.7330994897959184


Epochs:  65%|██████▌   | 26/40 [00:55<00:27,  1.98s/it]

Epoch 26/40 - Average Loss: 0.07272260227344102
Validation AUC: 0.6916454081632653


Epochs:  68%|██████▊   | 27/40 [00:57<00:25,  1.97s/it]

Epoch 27/40 - Average Loss: 0.06823032318303983
Validation AUC: 0.7436224489795918


Epochs:  70%|███████   | 28/40 [00:59<00:23,  1.97s/it]

Epoch 28/40 - Average Loss: 0.08634297870513466
Validation AUC: 0.6846301020408162


Epochs:  72%|███████▎  | 29/40 [01:01<00:21,  1.96s/it]

Epoch 29/40 - Average Loss: 0.04537510651991599
Validation AUC: 0.6568877551020408


Epochs:  75%|███████▌  | 30/40 [01:03<00:19,  1.97s/it]

Epoch 30/40 - Average Loss: 0.052101462691401444
Validation AUC: 0.6237244897959183


Epochs:  78%|███████▊  | 31/40 [01:05<00:18,  2.00s/it]

Epoch 31/40 - Average Loss: 0.04082259957471655
Validation AUC: 0.6616709183673469


Epochs:  80%|████████  | 32/40 [01:07<00:15,  1.99s/it]

Epoch 32/40 - Average Loss: 0.052013075807028346
Validation AUC: 0.6890943877551021


Epochs:  82%|████████▎ | 33/40 [01:09<00:13,  1.99s/it]

Epoch 33/40 - Average Loss: 0.047276974086546235
Validation AUC: 0.6941964285714285


Epochs:  85%|████████▌ | 34/40 [01:10<00:11,  1.98s/it]

Epoch 34/40 - Average Loss: 0.05203561500335733
Validation AUC: 0.6948341836734694


Epochs:  88%|████████▊ | 35/40 [01:12<00:09,  1.97s/it]

Epoch 35/40 - Average Loss: 0.03982804728568428
Validation AUC: 0.6951530612244897


Epochs:  90%|█████████ | 36/40 [01:14<00:07,  1.97s/it]

Epoch 36/40 - Average Loss: 0.034342571161687374
Validation AUC: 0.6753826530612245


Epochs:  92%|█████████▎| 37/40 [01:16<00:05,  1.99s/it]

Epoch 37/40 - Average Loss: 0.037908891591036484
Validation AUC: 0.6897321428571428


Epochs:  95%|█████████▌| 38/40 [01:18<00:04,  2.01s/it]

Epoch 38/40 - Average Loss: 0.0344986958662048
Validation AUC: 0.6852678571428572


Epochs:  98%|█████████▊| 39/40 [01:20<00:01,  1.99s/it]

Epoch 39/40 - Average Loss: 0.03545105261986868
Validation AUC: 0.7110969387755102


Epochs: 100%|██████████| 40/40 [01:22<00:00,  2.07s/it]

Epoch 40/40 - Average Loss: 0.03145034502570828
Validation AUC: 0.701530612244898





###RMSProp with L2 - ❌

In [29]:
optimizer = optim.RMSprop(resnet50.parameters(), lr=0.001)#, weight_decay=1e-5)

In [30]:
num_epochs = 40
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_auc = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet50.train()  # Set the model to training mode

    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet50(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        epoch_loss += loss.item()

    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    # Validation loop
    resnet50.eval()  # Set the model to evaluation mode

    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet50(images)
            probabilities = torch.softmax(outputs, dim=1)

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")

    # Check if current accuracy is better than the previous best accuracy
    if auc > best_auc:
        best_auc = auc
        torch.save(resnet50.state_dict(), f"resnet50_{best_auc:.4f}.pth")  # Save the model with AUC in the filename
        print("Saved")

Epochs:   0%|          | 0/40 [00:01<?, ?it/s]

Epoch 1/40 - Average Loss: 1.472257886495855
Validation AUC: 0.5770089285714286


Epochs:   2%|▎         | 1/40 [00:02<01:38,  2.52s/it]

Saved


Epochs:   2%|▎         | 1/40 [00:04<01:38,  2.52s/it]

Epoch 2/40 - Average Loss: 0.9978978137175242
Validation AUC: 0.6248405612244898


Epochs:   5%|▌         | 2/40 [00:05<01:37,  2.57s/it]

Saved


Epochs:   5%|▌         | 2/40 [00:06<01:37,  2.57s/it]

Epoch 3/40 - Average Loss: 0.9055871533022987
Validation AUC: 0.6511479591836735


Epochs:   8%|▊         | 3/40 [00:07<01:39,  2.68s/it]

Saved


Epochs:   8%|▊         | 3/40 [00:09<01:39,  2.68s/it]

Epoch 4/40 - Average Loss: 0.8619385030534532
Validation AUC: 0.6972257653061226


Epochs:  10%|█         | 4/40 [00:12<02:03,  3.43s/it]

Saved


Epochs:  12%|█▎        | 5/40 [00:14<01:42,  2.91s/it]

Epoch 5/40 - Average Loss: 0.7597777860032188
Validation AUC: 0.6865433673469389


Epochs:  15%|█▌        | 6/40 [00:16<01:27,  2.57s/it]

Epoch 6/40 - Average Loss: 0.5871662646532059
Validation AUC: 0.6747448979591837


Epochs:  18%|█▊        | 7/40 [00:18<01:17,  2.35s/it]

Epoch 7/40 - Average Loss: 0.5269962598880132
Validation AUC: 0.6594387755102041


Epochs:  20%|██        | 8/40 [00:20<01:10,  2.22s/it]

Epoch 8/40 - Average Loss: 0.38209772523906493
Validation AUC: 0.6403061224489796


Epochs:  22%|██▎       | 9/40 [00:22<01:06,  2.15s/it]

Epoch 9/40 - Average Loss: 0.30661532531181973
Validation AUC: 0.6734693877551019


Epochs:  25%|██▌       | 10/40 [00:24<01:02,  2.08s/it]

Epoch 10/40 - Average Loss: 0.30995765493975747
Validation AUC: 0.6594387755102041


Epochs:  28%|██▊       | 11/40 [00:26<00:58,  2.03s/it]

Epoch 11/40 - Average Loss: 0.24201300367712975
Validation AUC: 0.6498724489795918


Epochs:  30%|███       | 12/40 [00:28<00:57,  2.05s/it]

Epoch 12/40 - Average Loss: 0.20965741698940596
Validation AUC: 0.6278698979591837


Epochs:  32%|███▎      | 13/40 [00:30<00:54,  2.00s/it]

Epoch 13/40 - Average Loss: 0.2438213974237442
Validation AUC: 0.6160714285714285


Epochs:  35%|███▌      | 14/40 [00:31<00:51,  1.97s/it]

Epoch 14/40 - Average Loss: 0.2551862278746234
Validation AUC: 0.6599170918367347


Epochs:  38%|███▊      | 15/40 [00:33<00:49,  1.97s/it]

Epoch 15/40 - Average Loss: 0.18595228602902758
Validation AUC: 0.6857461734693877


Epochs:  40%|████      | 16/40 [00:35<00:47,  1.98s/it]

Epoch 16/40 - Average Loss: 0.15396875308619606
Validation AUC: 0.6771364795918366


Epochs:  42%|████▎     | 17/40 [00:37<00:45,  1.97s/it]

Epoch 17/40 - Average Loss: 0.1877460450761848
Validation AUC: 0.6753826530612245


Epochs:  45%|████▌     | 18/40 [00:39<00:42,  1.95s/it]

Epoch 18/40 - Average Loss: 0.2114092753165298
Validation AUC: 0.6524234693877551


Epochs:  48%|████▊     | 19/40 [00:41<00:40,  1.94s/it]

Epoch 19/40 - Average Loss: 0.13660173759692246
Validation AUC: 0.6862244897959182


Epochs:  50%|█████     | 20/40 [00:43<00:38,  1.93s/it]

Epoch 20/40 - Average Loss: 0.11920413840562105
Validation AUC: 0.6651785714285714


Epochs:  52%|█████▎    | 21/40 [00:45<00:36,  1.92s/it]

Epoch 21/40 - Average Loss: 0.09702013189800912
Validation AUC: 0.6817602040816325


Epochs:  55%|█████▌    | 22/40 [00:47<00:34,  1.93s/it]

Epoch 22/40 - Average Loss: 0.07217827729052967
Validation AUC: 0.6677295918367346


Epochs:  57%|█████▊    | 23/40 [00:49<00:33,  1.96s/it]

Epoch 23/40 - Average Loss: 0.0730606800255676
Validation AUC: 0.6565688775510203


Epochs:  60%|██████    | 24/40 [00:51<00:31,  1.95s/it]

Epoch 24/40 - Average Loss: 0.07370344116093798
Validation AUC: 0.6865433673469388


Epochs:  62%|██████▎   | 25/40 [00:53<00:29,  1.94s/it]

Epoch 25/40 - Average Loss: 0.10160602365309994
Validation AUC: 0.6619897959183674


Epochs:  65%|██████▌   | 26/40 [00:55<00:26,  1.93s/it]

Epoch 26/40 - Average Loss: 0.07280013426983108
Validation AUC: 0.6760204081632654


Epochs:  68%|██████▊   | 27/40 [00:57<00:24,  1.92s/it]

Epoch 27/40 - Average Loss: 0.07880147810404499
Validation AUC: 0.6913265306122449


Epochs:  70%|███████   | 28/40 [00:59<00:22,  1.91s/it]

Epoch 28/40 - Average Loss: 0.1181589458655152
Validation AUC: 0.6865433673469388


Epochs:  72%|███████▎  | 29/40 [01:00<00:21,  1.92s/it]

Epoch 29/40 - Average Loss: 0.17233905961943996
Validation AUC: 0.6718749999999999


Epochs:  75%|███████▌  | 30/40 [01:02<00:19,  1.94s/it]

Epoch 30/40 - Average Loss: 0.1526820218294031
Validation AUC: 0.6648596938775511


Epochs:  78%|███████▊  | 31/40 [01:04<00:17,  1.94s/it]

Epoch 31/40 - Average Loss: 0.06579494083093272
Validation AUC: 0.677295918367347


Epochs:  80%|████████  | 32/40 [01:06<00:15,  1.93s/it]

Epoch 32/40 - Average Loss: 0.15602991776540875
Validation AUC: 0.6846301020408163


Epochs:  82%|████████▎ | 33/40 [01:08<00:13,  1.92s/it]

Epoch 33/40 - Average Loss: 0.0842705174194028
Validation AUC: 0.6769770408163266


Epochs:  85%|████████▌ | 34/40 [01:10<00:11,  1.92s/it]

Epoch 34/40 - Average Loss: 0.11845841163045003
Validation AUC: 0.6808035714285714


Epochs:  88%|████████▊ | 35/40 [01:12<00:09,  1.91s/it]

Epoch 35/40 - Average Loss: 0.0936420554191702
Validation AUC: 0.650829081632653


Epochs:  90%|█████████ | 36/40 [01:14<00:07,  1.92s/it]

Epoch 36/40 - Average Loss: 0.1728747254754934
Validation AUC: 0.696109693877551


Epochs:  92%|█████████▎| 37/40 [01:16<00:05,  1.95s/it]

Epoch 37/40 - Average Loss: 0.13776122266426682
Validation AUC: 0.6814413265306123


Epochs:  95%|█████████▌| 38/40 [01:18<00:03,  1.95s/it]

Epoch 38/40 - Average Loss: 0.0393452079475133
Validation AUC: 0.6935586734693877


Epochs:  98%|█████████▊| 39/40 [01:20<00:01,  1.93s/it]

Epoch 39/40 - Average Loss: 0.06899645169162089
Validation AUC: 0.6827168367346939


Epochs: 100%|██████████| 40/40 [01:22<00:00,  2.06s/it]

Epoch 40/40 - Average Loss: 0.031419607720130846
Validation AUC: 0.6836734693877551





#Resnet101

In [31]:
# Define the device for training (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet101 model
resnet101 = models.resnet101(weights='ResNet101_Weights.DEFAULT')

# Modify the first layer to accept single-channel grayscale images
resnet101.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the last fully connected layer for binary classification with softmax activation
num_classes = 2  # 2 classes: 1 or 0
resnet101.fc = nn.Sequential(
    nn.Linear(resnet101.fc.in_features, num_classes)
)

Downloading: "https://download.pytorch.org/models/resnet101-cd907fc2.pth" to /root/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth
100%|██████████| 171M/171M [00:02<00:00, 74.4MB/s]


In [32]:
# Move the model to the device
resnet101 = resnet101.to(device)

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(resnet101.parameters(), lr=0.001)

num_epochs = 40

In [33]:
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_auc = 0.0
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet101.train()  # Set the model to training mode

    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet101(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        epoch_loss += loss.item()

    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    # Validation loop
    resnet101.eval()  # Set the model to evaluation mode

    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet101(images)
            probabilities = torch.softmax(outputs, dim=1)

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")

    # Check if current accuracy is better than the previous best accuracy
    if auc > best_auc:
        best_auc = auc
        torch.save(resnet101.state_dict(), f"resnet101_{best_auc:.4f}.pth")  # Save the model with AUC in the filename
        print("Saved")

Epochs:   0%|          | 0/40 [00:02<?, ?it/s]

Epoch 1/40 - Average Loss: 0.7082766029569838
Validation AUC: 0.7133290816326531


Epochs:   2%|▎         | 1/40 [00:02<01:39,  2.56s/it]

Saved


Epochs:   5%|▌         | 2/40 [00:04<01:21,  2.15s/it]

Epoch 2/40 - Average Loss: 0.6015282703770531
Validation AUC: 0.7005739795918368


Epochs:   8%|▊         | 3/40 [00:06<01:16,  2.08s/it]

Epoch 3/40 - Average Loss: 0.5215223696496751
Validation AUC: 0.6916454081632654


Epochs:   8%|▊         | 3/40 [00:08<01:16,  2.08s/it]

Epoch 4/40 - Average Loss: 0.3742721875508626
Validation AUC: 0.7793367346938775


Epochs:  10%|█         | 4/40 [00:09<01:22,  2.30s/it]

Saved


Epochs:  12%|█▎        | 5/40 [00:10<01:15,  2.14s/it]

Epoch 5/40 - Average Loss: 0.39737808373239303
Validation AUC: 0.7551020408163265


Epochs:  12%|█▎        | 5/40 [00:12<01:15,  2.14s/it]

Epoch 6/40 - Average Loss: 0.3028826481766171
Validation AUC: 0.7930484693877551


Epochs:  15%|█▌        | 6/40 [00:13<01:17,  2.28s/it]

Saved


Epochs:  18%|█▊        | 7/40 [00:15<01:10,  2.15s/it]

Epoch 7/40 - Average Loss: 0.2057160445385509
Validation AUC: 0.7388392857142857


Epochs:  18%|█▊        | 7/40 [00:17<01:10,  2.15s/it]

Epoch 8/40 - Average Loss: 0.25109411403536797
Validation AUC: 0.8191964285714286


Epochs:  20%|██        | 8/40 [00:17<01:09,  2.17s/it]

Saved


Epochs:  22%|██▎       | 9/40 [00:19<01:04,  2.09s/it]

Epoch 9/40 - Average Loss: 0.22016756898827022
Validation AUC: 0.7126913265306122


Epochs:  25%|██▌       | 10/40 [00:21<01:02,  2.09s/it]

Epoch 10/40 - Average Loss: 0.17017855164077547
Validation AUC: 0.8045280612244898


Epochs:  28%|██▊       | 11/40 [00:23<00:59,  2.05s/it]

Epoch 11/40 - Average Loss: 0.13826865620083278
Validation AUC: 0.782844387755102


Epochs:  30%|███       | 12/40 [00:25<00:55,  1.98s/it]

Epoch 12/40 - Average Loss: 0.26378602037827176
Validation AUC: 0.757015306122449


Epochs:  32%|███▎      | 13/40 [00:27<00:52,  1.93s/it]

Epoch 13/40 - Average Loss: 0.19135851247443092
Validation AUC: 0.7346938775510203


Epochs:  35%|███▌      | 14/40 [00:29<00:49,  1.91s/it]

Epoch 14/40 - Average Loss: 0.20683971544106802
Validation AUC: 0.7751913265306123


Epochs:  35%|███▌      | 14/40 [00:30<00:49,  1.91s/it]

Epoch 15/40 - Average Loss: 0.1623950410220358
Validation AUC: 0.871811224489796


Epochs:  38%|███▊      | 15/40 [00:31<00:49,  2.00s/it]

Saved


Epochs:  40%|████      | 16/40 [00:33<00:47,  1.99s/it]

Epoch 16/40 - Average Loss: 0.10728854929200476
Validation AUC: 0.8376913265306123


Epochs:  42%|████▎     | 17/40 [00:35<00:46,  2.02s/it]

Epoch 17/40 - Average Loss: 0.09089065001656611
Validation AUC: 0.8600127551020408


Epochs:  42%|████▎     | 17/40 [00:37<00:46,  2.02s/it]

Epoch 18/40 - Average Loss: 0.06019701777646939
Validation AUC: 0.8765943877551021


Epochs:  45%|████▌     | 18/40 [00:37<00:46,  2.12s/it]

Saved


Epochs:  48%|████▊     | 19/40 [00:39<00:42,  2.04s/it]

Epoch 19/40 - Average Loss: 0.07094873524167472
Validation AUC: 0.8708545918367346


Epochs:  50%|█████     | 20/40 [00:41<00:39,  1.98s/it]

Epoch 20/40 - Average Loss: 0.056279734113357134
Validation AUC: 0.8756377551020407


Epochs:  50%|█████     | 20/40 [00:43<00:39,  1.98s/it]

Epoch 21/40 - Average Loss: 0.036683099667748645
Validation AUC: 0.8985969387755103


Epochs:  52%|█████▎    | 21/40 [00:43<00:38,  2.05s/it]

Saved


Epochs:  55%|█████▌    | 22/40 [00:45<00:37,  2.06s/it]

Epoch 22/40 - Average Loss: 0.03509842235750208
Validation AUC: 0.8861607142857143


Epochs:  57%|█████▊    | 23/40 [00:47<00:34,  2.05s/it]

Epoch 23/40 - Average Loss: 0.02864111275024091
Validation AUC: 0.8494897959183673


Epochs:  60%|██████    | 24/40 [00:49<00:32,  2.04s/it]

Epoch 24/40 - Average Loss: 0.09246004587556753
Validation AUC: 0.7219387755102041


Epochs:  62%|██████▎   | 25/40 [00:51<00:29,  1.98s/it]

Epoch 25/40 - Average Loss: 0.19461039919406176
Validation AUC: 0.7729591836734694


Epochs:  65%|██████▌   | 26/40 [00:53<00:27,  1.93s/it]

Epoch 26/40 - Average Loss: 0.13707284153335625
Validation AUC: 0.8335459183673469


Epochs:  68%|██████▊   | 27/40 [00:55<00:24,  1.90s/it]

Epoch 27/40 - Average Loss: 0.11584677857657273
Validation AUC: 0.8316326530612245


Epochs:  70%|███████   | 28/40 [00:56<00:22,  1.88s/it]

Epoch 28/40 - Average Loss: 0.1201984361331496
Validation AUC: 0.8737244897959184


Epochs:  72%|███████▎  | 29/40 [00:58<00:20,  1.87s/it]

Epoch 29/40 - Average Loss: 0.13636857002145714
Validation AUC: 0.8718112244897959


Epochs:  75%|███████▌  | 30/40 [01:00<00:19,  1.92s/it]

Epoch 30/40 - Average Loss: 0.08816123122556342
Validation AUC: 0.8632015306122449


Epochs:  78%|███████▊  | 31/40 [01:02<00:17,  1.97s/it]

Epoch 31/40 - Average Loss: 0.10650379666023785
Validation AUC: 0.8191964285714284


Epochs:  80%|████████  | 32/40 [01:04<00:15,  1.94s/it]

Epoch 32/40 - Average Loss: 0.12008588424780303
Validation AUC: 0.8266900510204082


Epochs:  82%|████████▎ | 33/40 [01:06<00:13,  1.91s/it]

Epoch 33/40 - Average Loss: 0.0983991301101115
Validation AUC: 0.8571428571428572


Epochs:  85%|████████▌ | 34/40 [01:08<00:11,  1.89s/it]

Epoch 34/40 - Average Loss: 0.09754408598463568
Validation AUC: 0.8252551020408163


Epochs:  88%|████████▊ | 35/40 [01:10<00:09,  1.88s/it]

Epoch 35/40 - Average Loss: 0.04352296213619411
Validation AUC: 0.8137755102040816


Epochs:  90%|█████████ | 36/40 [01:12<00:07,  1.89s/it]

Epoch 36/40 - Average Loss: 0.07633870475304623
Validation AUC: 0.8533163265306122


Epochs:  90%|█████████ | 36/40 [01:14<00:07,  1.89s/it]

Epoch 37/40 - Average Loss: 0.12815161718107346


Epochs:  92%|█████████▎| 37/40 [01:14<00:06,  2.07s/it]

Validation AUC: 0.7949617346938774


Epochs:  95%|█████████▌| 38/40 [01:17<00:04,  2.18s/it]

Epoch 38/40 - Average Loss: 0.12047055529223548
Validation AUC: 0.8478954081632654


Epochs:  98%|█████████▊| 39/40 [01:19<00:02,  2.10s/it]

Epoch 39/40 - Average Loss: 0.11847552667475408
Validation AUC: 0.8128188775510204


Epochs: 100%|██████████| 40/40 [01:20<00:00,  2.02s/it]

Epoch 40/40 - Average Loss: 0.12672050116169783
Validation AUC: 0.7946428571428571



