In [4]:
import os 
import matplotlib.pyplot as plt
from PIL import Image
import torch 
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torchvision.models as models
from tqdm import tqdm
import pandas as pd
import numpy as np

## Creation of Train and Validation set

In [5]:
df = pd.read_csv('metadata.csv')

In [6]:
paths = df.groupby(["lat","lon","plume"]).count()[['path']].sort_values("path", ascending = False)

In [7]:
paths['lists'] = np.nan

In [8]:
temp = []

for i,row in enumerate(paths.iterrows()):
  row = row[0]
  temp.append(list(df['path'][(df['lat'] == row[0]) & (df['lon'] == row[1])]))

paths['lists'] = temp

In [9]:
paths = paths.reset_index()

In [10]:
paths

Unnamed: 0,lat,lon,plume,path,lists
0,29.631951,35.952379,no,21,[images/no_plume/20230304_methane_mixing_ratio...
1,32.713854,44.609398,no,19,[images/no_plume/20230219_methane_mixing_ratio...
2,33.990812,39.641866,no,18,[images/no_plume/20230206_methane_mixing_ratio...
3,28.510000,77.442400,yes,17,[images/plume/20230327_methane_mixing_ratio_id...
4,36.596520,38.321405,no,15,[images/no_plume/20230122_methane_mixing_ratio...
...,...,...,...,...,...
96,21.039986,-77.824694,no,1,[images/no_plume/20230227_methane_mixing_ratio...
97,24.907500,67.023000,yes,1,[images/plume/20230404_methane_mixing_ratio_id...
98,23.763333,86.396667,yes,1,[images/plume/20230129_methane_mixing_ratio_id...
99,23.740000,90.595000,yes,1,[images/plume/20230206_methane_mixing_ratio_id...


### Split By Location

In [11]:
val_indexs = list(paths[paths['plume'] == 'yes'].sample(int(len(paths)*0.1)).index) +\
list(paths[paths['plume'] == 'no'].sample(int(len(paths)*0.1)).index)

In [13]:
# creating validition set
import os
import random
import shutil

# set up paths
data_dir = 'images'
val_dir = 'validation'
train_dir='train'

plume_dir = 'plume'
no_plume_dir = 'no_plume'



# create validation directories
os.makedirs(os.path.join(train_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(train_dir, no_plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, no_plume_dir), exist_ok=True)

# get filenames from true and false directories
val_plume_files = set([y.strip("images/plume/")+str('.tif') for x in list(paths.lists[paths.index.isin(val_indexs[:10])]) for y in x])
val_no_plume_files = set([y.strip("images/no_plume/")+str('.tif') for x in list(paths.lists[paths.index.isin(val_indexs[10:])]) for y in x])
train_files = paths[~paths.index.isin(val_indexs)]
train_plume_files = set([y.strip("images/plume/")+str('.tif') for x in list(train_files.lists[train_files.plume == "yes"]) for y in x])
train_no_plume_files = set([y.strip("images/no_plume/")+str('.tif') for x in list(train_files.lists[train_files.plume == "no"]) for y in x])

# # move sampled files to validation directories
for filename in val_plume_files:
    shutil.move(os.path.join(data_dir, plume_dir, filename),
                os.path.join( val_dir, plume_dir))
for filename in val_no_plume_files:
    shutil.move(os.path.join(data_dir, no_plume_dir, filename),
                os.path.join(val_dir, no_plume_dir))
for filename in train_plume_files:
    shutil.move(os.path.join(data_dir, plume_dir, filename),
                os.path.join( train_dir, plume_dir))
for filename in train_no_plume_files:
    shutil.move(os.path.join(data_dir, no_plume_dir, filename),
                os.path.join(train_dir, no_plume_dir))

### Split By Having each location in training set

In [14]:
import random

val_index = [] 

for x in list(paths.lists):
  val_index.extend(set(list(random.sample(x, int(len(x)*0.35)))))

In [15]:
L = set([y for x in paths.lists for y in x])

for item in val_index:
    if item in L:
        L.remove(item)

In [81]:
len(L)

319

In [82]:
len(set([y for x in paths.lists for y in x]))

428

In [83]:
len(val_index)

109

In [16]:
# creating validition set
import os
import random
import shutil

%cd '/content/drive/Shareddrives/QB Hacakthon'

# set up paths
data_dir = 'images'
val_dir = 'validation_each_location'
train_dir='train_each_location'

plume_dir = 'plume'
no_plume_dir = 'no_plume'

val_plume_files = []
val_no_plume_files = []
train_plume_files = []
train_no_plume_files = []

# create validation directories
os.makedirs(os.path.join(train_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(train_dir, no_plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, plume_dir), exist_ok=True)
os.makedirs(os.path.join(val_dir, no_plume_dir), exist_ok=True)

# get filenames from true and false directories
for val in val_index:
  if 'no_plume' in val:
    val_no_plume_files.append(val)
  else:
    val_plume_files.append(val)

for val in L:
  if 'no_plume' in val:
    train_no_plume_files.append(val)
  else:
    train_plume_files.append(val)

val_plume_files = set([y.strip("images/plume/")+str('.tif') for y in val_plume_files])
val_no_plume_files = set([y.strip("images/no_plume/")+str('.tif') for y in val_no_plume_files])
train_plume_files = set([y.strip("images/plume/")+str('.tif') for y in train_plume_files])
train_no_plume_files = set([y.strip("images/no_plume/")+str('.tif') for y in train_no_plume_files])

print(len(val_plume_files))
print(len(val_no_plume_files))
print(len(train_plume_files))
print(len(train_no_plume_files))

# # move sampled files to validation directories
for filename in val_plume_files:
    shutil.move(os.path.join(data_dir, plume_dir, filename),
                os.path.join( val_dir, plume_dir))
for filename in val_no_plume_files:
    shutil.move(os.path.join(data_dir, no_plume_dir, filename),
                os.path.join(val_dir, no_plume_dir))
for filename in train_plume_files:
    shutil.move(os.path.join(data_dir, plume_dir, filename),
                os.path.join( train_dir, plume_dir))
for filename in train_no_plume_files:
    shutil.move(os.path.join(data_dir, no_plume_dir, filename),
                os.path.join(train_dir, no_plume_dir))

[WinError 3] The system cannot find the path specified: "'/content/drive/Shareddrives/QB Hacakthon'"
c:\Users\CompuTop\Desktop\Hackathon QB\methane-leak-detection-in-satellite-imagery\models
48
61
166
153


## Dataset and Dataloader

In [16]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms

class CustomDataset(Dataset):
    def __init__(self, root_dir, augment=False):
        self.root_dir = root_dir
        self.augment = augment
        self.image_list = []
        self.labels = []
        self.augmented_images = []

        classes = os.listdir(root_dir)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                image_names = os.listdir(class_dir)
                tiff_image_names = [img_name for img_name in image_names if img_name.lower().endswith('.tiff') or img_name.lower().endswith('.tif')]
                self.image_list.extend([os.path.join(class_dir, img_name) for img_name in tiff_image_names])
                self.labels.extend([1 if class_name == 'plume' else 0] * len(tiff_image_names))

                if augment:
                    augmented_images = []
                    for img_name in image_names:
                        img_path = os.path.join(class_dir, img_name)
                        image = Image.open(img_path)
                        image = image.resize((64, 64))  # Resize to 64x64

                        # Apply augmentation transformations
                        transform = transforms.Compose([
                            transforms.RandomHorizontalFlip(),
                            transforms.RandomRotation(10),
                            transforms.ToTensor()
                        ])
                        augmented_image = transform(image)
                        augmented_images.append(augmented_image)

                    self.augmented_images.extend(augmented_images)

        if augment:
            self.image_list += self.augmented_images
            self.labels += self.labels * len(self.augmented_images)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = self.image_list[idx]
        label = self.labels[idx]

        if isinstance(image, str):
            image = Image.open(image)
            image = image.resize((64, 64))  # Resize to 64x64

            transform = transforms.Compose([
                transforms.ToTensor()
            ])
            image = transform(image)

        return image, label


In [4]:
gen = torch.Generator()
gen.manual_seed(0)

root_dir = "/content/drive/Shareddrives/QB Hacakthon/train_each_location"
val_dir = "/content/drive/Shareddrives/QB Hacakthon/validation_each_location"

train_dataset = CustomDataset(root_dir, augment=True)
val_dataset = CustomDataset(val_dir, augment=False)


# Define batch size for training and validation
batch_size = 32

# Create the train DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create the validation DataLoader
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Verify the train and validation data loaders
for images, labels in train_loader:
    print(f"Train batch - Images: {images.shape}, Labels: {labels.shape}")
    break

for images, labels in val_loader:
    print(f"Validation batch - Images: {images.shape}, Labels: {labels.shape}")
    break

Train batch - Images: torch.Size([32, 1, 64, 64]), Labels: torch.Size([32])
Validation batch - Images: torch.Size([32, 1, 64, 64]), Labels: torch.Size([32])


## Resnet101

In [None]:
# Define the device for training (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet101 model
resnet = models.resnet101(weights='ResNet101_Weights.IMAGENET1K_V1')

# Modify the first layer to accept single-channel grayscale images
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the last fully connected layer for binary classification with softmax activation
num_classes = 2  # 2 classes: 1 or 0
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, num_classes)
)

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:02<00:00, 80.1MB/s]


In [None]:
# Move the model to the device
resnet = resnet.to(device)

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)

num_epochs = 20

In [None]:
# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet.train()  # Set the model to training mode
    
    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)
        
        # Forward pass
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate the loss
        epoch_loss += loss.item()
    
    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)
#     tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss}")

    # Validation loop
    resnet.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet(images)
            _, predicted = torch.max(outputs.data, 1)

            # Calculate accuracy
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    accuracy = total_correct / total_samples
    print(f"Validation Accuracy: {accuracy}")


Epochs:   5%|▌         | 1/20 [00:32<10:14, 32.33s/it]

Validation Accuracy: 0.2702702702702703


Epochs:  10%|█         | 2/20 [00:35<04:28, 14.94s/it]

Validation Accuracy: 0.5225225225225225


Epochs:  15%|█▌        | 3/20 [00:38<02:44,  9.67s/it]

Validation Accuracy: 0.6126126126126126


Epochs:  20%|██        | 4/20 [00:41<01:51,  6.98s/it]

Validation Accuracy: 0.6756756756756757


Epochs:  25%|██▌       | 5/20 [00:44<01:21,  5.45s/it]

Validation Accuracy: 0.7657657657657657


Epochs:  30%|███       | 6/20 [00:46<01:03,  4.50s/it]

Validation Accuracy: 0.43243243243243246


Epochs:  35%|███▌      | 7/20 [00:49<00:50,  3.91s/it]

Validation Accuracy: 0.6216216216216216


Epochs:  40%|████      | 8/20 [00:52<00:42,  3.51s/it]

Validation Accuracy: 0.46846846846846846


Epochs:  45%|████▌     | 9/20 [00:55<00:36,  3.32s/it]

Validation Accuracy: 0.5855855855855856


Epochs:  50%|█████     | 10/20 [00:57<00:31,  3.17s/it]

Validation Accuracy: 0.5315315315315315


Epochs:  55%|█████▌    | 11/20 [01:00<00:27,  3.03s/it]

Validation Accuracy: 0.38738738738738737


Epochs:  60%|██████    | 12/20 [01:03<00:23,  2.92s/it]

Validation Accuracy: 0.7027027027027027


Epochs:  65%|██████▌   | 13/20 [01:05<00:20,  2.87s/it]

Validation Accuracy: 0.4774774774774775


Epochs:  70%|███████   | 14/20 [01:08<00:17,  2.90s/it]

Validation Accuracy: 0.5765765765765766


Epochs:  75%|███████▌  | 15/20 [01:11<00:14,  2.84s/it]

Validation Accuracy: 0.6846846846846847


Epochs:  80%|████████  | 16/20 [01:14<00:11,  2.79s/it]

Validation Accuracy: 0.6846846846846847


Epochs:  85%|████████▌ | 17/20 [01:16<00:08,  2.75s/it]

Validation Accuracy: 0.6486486486486487


Epochs:  90%|█████████ | 18/20 [01:19<00:05,  2.77s/it]

Validation Accuracy: 0.6306306306306306


Epochs:  95%|█████████▌| 19/20 [01:22<00:02,  2.80s/it]

Validation Accuracy: 0.4864864864864865


Epochs: 100%|██████████| 20/20 [01:25<00:00,  4.27s/it]

Validation Accuracy: 0.6756756756756757





In [None]:
model_path = 'resnet_model.pth'

# Save the model
torch.save(resnet.state_dict(), model_path)

## Wide_ResNet50

In [5]:
# Define the device for training (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet101 model
resnet = models.wide_resnet50_2(weights='Wide_ResNet50_2_Weights.IMAGENET1K_V1')

# Modify the first layer to accept single-channel grayscale images
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the last fully connected layer for binary classification with softmax activation
num_classes = 2  # 2 classes: 1 or 0
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, num_classes)
)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:01<00:00, 81.4MB/s]


In [6]:
max_lr = 1e-2
num_epochs = 20

# Move the model to the device
resnet = resnet.to(device)

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.AdamW(resnet.parameters(), lr=max_lr)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=num_epochs,
                                               steps_per_epoch=len(train_loader))

In [7]:
# Retrieve the current learning rate of the optimizer
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [8]:
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_accuracy = 0.0
lrs = []

# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet.train()  # Set the model to training mode
    
    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)
        
        # Forward pass
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Step the learning rate
        lrs.append(get_lr(optimizer))
        scheduler.step() 
        
        # Accumulate the loss
        epoch_loss += loss.item()
    
    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)

    # Validation loop
    resnet.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0
    predicted_labels = []
    true_labels = []
    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet(images)
            _, predicted = torch.max(outputs.data, 1)
            probabilities = torch.softmax(outputs, dim=1)

            # Calculate accuracy
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

            # Collect predicted and true labels for classification report
            predicted_labels.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    print(f"Validation Accuracy: {accuracy}")

    # Compute F1 score and classification report
    f1 = f1_score(true_labels, predicted_labels, average='macro')
    classification_rep = classification_report(true_labels, predicted_labels)
    print(f"Validation F1 Score: {f1}")
    print("Classification Report:")
    print(classification_rep)

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")
    
    if accuracy > 0.8 and accuracy > best_accuracy:
      print('-----------------------------------------------------------')
      model_path = 'Wide_ResNet50_Weights'+str(epoch)+'.pth'
      best_accuracy = accuracy

      # Save the model
      torch.save(resnet.state_dict(), model_path)

Epochs:   5%|▌         | 1/20 [00:38<12:16, 38.77s/it]

Validation Accuracy: 0.7889908256880734
Validation F1 Score: 0.7800298324120383
Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.89      0.82        61
           1       0.82      0.67      0.74        48

    accuracy                           0.79       109
   macro avg       0.80      0.78      0.78       109
weighted avg       0.79      0.79      0.79       109

Validation AUC: 0.8333333333333334


Epochs:  10%|█         | 2/20 [00:41<05:20, 17.81s/it]

Validation Accuracy: 0.6146788990825688
Validation F1 Score: 0.6143867924528301
Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.57      0.62        61
           1       0.55      0.67      0.60        48

    accuracy                           0.61       109
   macro avg       0.62      0.62      0.61       109
weighted avg       0.63      0.61      0.62       109

Validation AUC: 0.6256830601092895


Epochs:  15%|█▌        | 3/20 [00:44<03:07, 11.01s/it]

Validation Accuracy: 0.7706422018348624
Validation F1 Score: 0.7642121657869689
Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.84      0.80        61
           1       0.77      0.69      0.73        48

    accuracy                           0.77       109
   macro avg       0.77      0.76      0.76       109
weighted avg       0.77      0.77      0.77       109

Validation AUC: 0.826844262295082


Epochs:  20%|██        | 4/20 [00:47<02:05,  7.82s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7228813559322034
Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.72      0.75        61
           1       0.67      0.73      0.70        48

    accuracy                           0.72       109
   macro avg       0.72      0.73      0.72       109
weighted avg       0.73      0.72      0.73       109

Validation AUC: 0.7745901639344261


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epochs:  25%|██▌       | 5/20 [00:50<01:30,  6.06s/it]

Validation Accuracy: 0.44036697247706424
Validation F1 Score: 0.3057324840764331
Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.44      1.00      0.61        48

    accuracy                           0.44       109
   macro avg       0.22      0.50      0.31       109
weighted avg       0.19      0.44      0.27       109

Validation AUC: 0.5


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epochs:  30%|███       | 6/20 [00:54<01:15,  5.36s/it]

Validation Accuracy: 0.44036697247706424
Validation F1 Score: 0.3057324840764331
Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.44      1.00      0.61        48

    accuracy                           0.44       109
   macro avg       0.22      0.50      0.31       109
weighted avg       0.19      0.44      0.27       109

Validation AUC: 0.5


Epochs:  35%|███▌      | 7/20 [00:58<01:01,  4.75s/it]

Validation Accuracy: 0.4954128440366973
Validation F1 Score: 0.441025641025641
Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.16      0.27        61
           1       0.46      0.92      0.62        48

    accuracy                           0.50       109
   macro avg       0.59      0.54      0.44       109
weighted avg       0.60      0.50      0.42       109

Validation AUC: 0.3782445355191257


Epochs:  40%|████      | 8/20 [01:01<00:50,  4.17s/it]

Validation Accuracy: 0.6330275229357798
Validation F1 Score: 0.6126510305614783
Classification Report:
              precision    recall  f1-score   support

           0       0.96      0.36      0.52        61
           1       0.55      0.98      0.70        48

    accuracy                           0.63       109
   macro avg       0.75      0.67      0.61       109
weighted avg       0.78      0.63      0.60       109

Validation AUC: 0.8043032786885245


Epochs:  45%|████▌     | 9/20 [01:04<00:41,  3.77s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7241902834008096
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.61      0.71        61
           1       0.64      0.88      0.74        48

    accuracy                           0.72       109
   macro avg       0.75      0.74      0.72       109
weighted avg       0.76      0.72      0.72       109

Validation AUC: 0.7592213114754099


Epochs:  50%|█████     | 10/20 [01:07<00:35,  3.53s/it]

Validation Accuracy: 0.5688073394495413
Validation F1 Score: 0.5334668973681814
Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.26      0.41        61
           1       0.51      0.96      0.66        48

    accuracy                           0.57       109
   macro avg       0.70      0.61      0.53       109
weighted avg       0.72      0.57      0.52       109

Validation AUC: 0.7571721311475409
Validation Accuracy: 0.8165137614678899
Validation F1 Score: 0.8138661202185793
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.84      0.84        61
           1       0.79      0.79      0.79        48

    accuracy                           0.82       109
   macro avg       0.81      0.81      0.81       109
weighted avg       0.82      0.82      0.82       109

Validation AUC: 0.85724043715847
-----------------------------------------------------------


Epochs:  60%|██████    | 12/20 [01:13<00:26,  3.35s/it]

Validation Accuracy: 0.5779816513761468
Validation F1 Score: 0.5458333333333333
Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.28      0.42        61
           1       0.51      0.96      0.67        48

    accuracy                           0.58       109
   macro avg       0.70      0.62      0.55       109
weighted avg       0.73      0.58      0.53       109

Validation AUC: 0.8886612021857924
Validation Accuracy: 0.8256880733944955
Validation F1 Score: 0.8208012459980965
Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.89      0.85        61
           1       0.84      0.75      0.79        48

    accuracy                           0.83       109
   macro avg       0.83      0.82      0.82       109
weighted avg       0.83      0.83      0.82       109

Validation AUC: 0.8838797814207651
-----------------------------------------------------------


Epochs:  65%|██████▌   | 13/20 [01:16<00:23,  3.39s/it]

Validation Accuracy: 0.8348623853211009
Validation F1 Score: 0.8316746739876459
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.87      0.85        61
           1       0.83      0.79      0.81        48

    accuracy                           0.83       109
   macro avg       0.83      0.83      0.83       109
weighted avg       0.83      0.83      0.83       109

Validation AUC: 0.878756830601093
-----------------------------------------------------------


Epochs:  75%|███████▌  | 15/20 [01:24<00:17,  3.57s/it]

Validation Accuracy: 0.7981651376146789
Validation F1 Score: 0.7960884353741496
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.80      0.82        61
           1       0.76      0.79      0.78        48

    accuracy                           0.80       109
   macro avg       0.80      0.80      0.80       109
weighted avg       0.80      0.80      0.80       109

Validation AUC: 0.8777322404371585


Epochs:  80%|████████  | 16/20 [01:27<00:13,  3.37s/it]

Validation Accuracy: 0.7431192660550459
Validation F1 Score: 0.7430976430976431
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.67      0.75        61
           1       0.67      0.83      0.74        48

    accuracy                           0.74       109
   macro avg       0.75      0.75      0.74       109
weighted avg       0.76      0.74      0.74       109

Validation AUC: 0.8633879781420765


Epochs:  85%|████████▌ | 17/20 [01:30<00:09,  3.24s/it]

Validation Accuracy: 0.7339449541284404
Validation F1 Score: 0.7339449541284403
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.66      0.73        61
           1       0.66      0.83      0.73        48

    accuracy                           0.73       109
   macro avg       0.74      0.74      0.73       109
weighted avg       0.76      0.73      0.73       109

Validation AUC: 0.8551912568306012


Epochs:  90%|█████████ | 18/20 [01:33<00:06,  3.17s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7247474747474747
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.64      0.72        61
           1       0.65      0.83      0.73        48

    accuracy                           0.72       109
   macro avg       0.74      0.74      0.72       109
weighted avg       0.75      0.72      0.72       109

Validation AUC: 0.8623633879781422


Epochs:  95%|█████████▌| 19/20 [01:36<00:03,  3.16s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7247474747474747
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.64      0.72        61
           1       0.65      0.83      0.73        48

    accuracy                           0.72       109
   macro avg       0.74      0.74      0.72       109
weighted avg       0.75      0.72      0.72       109

Validation AUC: 0.8582650273224043


Epochs: 100%|██████████| 20/20 [01:39<00:00,  4.99s/it]

Validation Accuracy: 0.7339449541284404
Validation F1 Score: 0.7339449541284403
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.66      0.73        61
           1       0.66      0.83      0.73        48

    accuracy                           0.73       109
   macro avg       0.74      0.74      0.73       109
weighted avg       0.76      0.73      0.73       109

Validation AUC: 0.8534836065573771





In [9]:
resnet

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [10]:
resnet.load_state_dict(torch.load('/content/New_split_each_location_Wide_ResNet50.pth'))

resnet = resnet.to(device)

torch.save(resnet.state_dict(), '/content/drive/Shareddrives/QB Hacakthon/models path/Wide_ResNet50_2_Weights10.pth')

## ResNeXt101

In [24]:
# Define the device for training (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet101 model
#resnet = models.resnext101_64x4d(weights='ResNeXt101_64X4D_Weights.IMAGENET1K_V1')
resnet = models.resnet101(weights='ResNet101_Weights.IMAGENET1K_V1')

# Modify the first layer to accept single-channel grayscale images
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the last fully connected layer for binary classification with softmax activation
num_classes = 2  # 2 classes: 1 or 0
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, num_classes)
)

In [25]:
max_lr = 1e-2
num_epochs = 20
weight_decay = 1e-4

# Move the model to the device
resnet = resnet.to(device)

# Define the loss function (criterion)
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.AdamW(resnet.parameters(), lr=max_lr, weight_decay=weight_decay)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=num_epochs,
                                               steps_per_epoch=len(train_loader))

In [22]:
# Retrieve the current learning rate of the optimizer
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [23]:
from sklearn.metrics import classification_report, f1_score, roc_auc_score
best_accuracy = 0.0
lrs = []

# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    resnet.train()  # Set the model to training mode
    
    epoch_loss = 0.0  # Accumulator for epoch loss
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device)
        
        # Forward pass
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Step the learning rate
        lrs.append(get_lr(optimizer))
        scheduler.step() 
        
        # Accumulate the loss
        epoch_loss += loss.item()
    
    # Compute the average loss for the epoch
    epoch_loss /= len(train_loader)

    # Validation loop
    resnet.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0
    predicted_labels = []
    true_labels = []
    true_probabilities = []
    predicted_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float)
            labels = labels.to(device)

            # Forward pass
            outputs = resnet(images)
            _, predicted = torch.max(outputs.data, 1)
            probabilities = torch.softmax(outputs, dim=1)

            # Calculate accuracy
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

            # Collect predicted and true labels for classification report
            predicted_labels.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

            # Collect predicted probabilities and true labels
            predicted_probabilities.extend(probabilities[:, 1].cpu().numpy())  # Assuming binary classification
            true_probabilities.extend(labels.cpu().numpy())

    accuracy = total_correct / total_samples
    print(f"Validation Accuracy: {accuracy}")

    # Compute F1 score and classification report
    f1 = f1_score(true_labels, predicted_labels, average='macro')
    classification_rep = classification_report(true_labels, predicted_labels)
    print(f"Validation F1 Score: {f1}")
    print("Classification Report:")
    print(classification_rep)

    # Compute the AUC
    auc = roc_auc_score(true_probabilities, predicted_probabilities)
    print(f"Validation AUC: {auc}")
    
    if accuracy > 0.8 and accuracy > best_accuracy:
      print('-----------------------------------------------------------')
      model_path = 'ResNet101_Weights'+str(epoch)+'.pth'
      best_accuracy = accuracy

      # Save the model
      torch.save(resnet.state_dict(), model_path)

Epochs:   5%|▌         | 1/20 [00:03<00:58,  3.06s/it]

Validation Accuracy: 0.44954128440366975
Validation F1 Score: 0.32382133995037216
Classification Report:
              precision    recall  f1-score   support

           0       1.00      0.02      0.03        61
           1       0.44      1.00      0.62        48

    accuracy                           0.45       109
   macro avg       0.72      0.51      0.32       109
weighted avg       0.76      0.45      0.29       109

Validation AUC: 0.5715505464480874


Epochs:  10%|█         | 2/20 [00:05<00:52,  2.91s/it]

Validation Accuracy: 0.46788990825688076
Validation F1 Score: 0.3801960784313725
Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.08      0.15        61
           1       0.45      0.96      0.61        48

    accuracy                           0.47       109
   macro avg       0.58      0.52      0.38       109
weighted avg       0.60      0.47      0.35       109

Validation AUC: 0.7392418032786885


Epochs:  15%|█▌        | 3/20 [00:08<00:48,  2.84s/it]

Validation Accuracy: 0.5779816513761468
Validation F1 Score: 0.5750847457627117
Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.44      0.54        61
           1       0.51      0.75      0.61        48

    accuracy                           0.58       109
   macro avg       0.60      0.60      0.58       109
weighted avg       0.61      0.58      0.57       109

Validation AUC: 0.6618852459016394


Epochs:  20%|██        | 4/20 [00:11<00:45,  2.81s/it]

Validation Accuracy: 0.5137614678899083
Validation F1 Score: 0.47940884923853294
Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.23      0.35        61
           1       0.47      0.88      0.61        48

    accuracy                           0.51       109
   macro avg       0.59      0.55      0.48       109
weighted avg       0.60      0.51      0.46       109

Validation AUC: 0.4649931693989071


Epochs:  25%|██▌       | 5/20 [00:14<00:42,  2.83s/it]

Validation Accuracy: 0.6880733944954128
Validation F1 Score: 0.6802967563837129
Classification Report:
              precision    recall  f1-score   support

           0       0.94      0.48      0.63        61
           1       0.59      0.96      0.73        48

    accuracy                           0.69       109
   macro avg       0.76      0.72      0.68       109
weighted avg       0.78      0.69      0.67       109

Validation AUC: 0.8679986338797814


Epochs:  30%|███       | 6/20 [00:17<00:40,  2.88s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7067790530846485
Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.87      0.78        61
           1       0.76      0.54      0.63        48

    accuracy                           0.72       109
   macro avg       0.74      0.71      0.71       109
weighted avg       0.73      0.72      0.72       109

Validation AUC: 0.8060109289617486


Epochs:  35%|███▌      | 7/20 [00:20<00:36,  2.84s/it]

Validation Accuracy: 0.7706422018348624
Validation F1 Score: 0.7694000169247694
Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.75      0.79        61
           1       0.72      0.79      0.75        48

    accuracy                           0.77       109
   macro avg       0.77      0.77      0.77       109
weighted avg       0.78      0.77      0.77       109

Validation AUC: 0.8329918032786885


Epochs:  40%|████      | 8/20 [00:22<00:33,  2.83s/it]

Validation Accuracy: 0.6422018348623854
Validation F1 Score: 0.5725490196078431
Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.93      0.75        61
           1       0.76      0.27      0.40        48

    accuracy                           0.64       109
   macro avg       0.69      0.60      0.57       109
weighted avg       0.68      0.64      0.59       109

Validation AUC: 0.83025956284153


Epochs:  45%|████▌     | 9/20 [00:25<00:30,  2.80s/it]

Validation Accuracy: 0.7431192660550459
Validation F1 Score: 0.7429245283018868
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.64      0.74        61
           1       0.66      0.88      0.75        48

    accuracy                           0.74       109
   macro avg       0.76      0.76      0.74       109
weighted avg       0.77      0.74      0.74       109

Validation AUC: 0.8444330601092898
Validation Accuracy: 0.8532110091743119
Validation F1 Score: 0.8475524475524476
Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.93      0.88        61
           1       0.90      0.75      0.82        48

    accuracy                           0.85       109
   macro avg       0.86      0.84      0.85       109
weighted avg       0.86      0.85      0.85       109

Validation AUC: 0.855191256830601
-----------------------------------------------------------


Epochs:  55%|█████▌    | 11/20 [00:31<00:26,  2.95s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7119450317124736
Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.84      0.77        61
           1       0.74      0.58      0.65        48

    accuracy                           0.72       109
   macro avg       0.73      0.71      0.71       109
weighted avg       0.73      0.72      0.72       109

Validation AUC: 0.7862021857923498


Epochs:  60%|██████    | 12/20 [00:34<00:23,  2.90s/it]

Validation Accuracy: 0.7247706422018348
Validation F1 Score: 0.7119450317124736
Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.84      0.77        61
           1       0.74      0.58      0.65        48

    accuracy                           0.72       109
   macro avg       0.73      0.71      0.71       109
weighted avg       0.73      0.72      0.72       109

Validation AUC: 0.8480191256830601


Epochs:  65%|██████▌   | 13/20 [00:37<00:20,  2.87s/it]

Validation Accuracy: 0.8073394495412844
Validation F1 Score: 0.807079646017699
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.75      0.81        61
           1       0.74      0.88      0.80        48

    accuracy                           0.81       109
   macro avg       0.81      0.81      0.81       109
weighted avg       0.82      0.81      0.81       109

Validation AUC: 0.8859289617486339


Epochs:  70%|███████   | 14/20 [00:40<00:17,  2.85s/it]

Validation Accuracy: 0.7889908256880734
Validation F1 Score: 0.7889908256880733
Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.70      0.79        61
           1       0.70      0.90      0.79        48

    accuracy                           0.79       109
   macro avg       0.80      0.80      0.79       109
weighted avg       0.81      0.79      0.79       109

Validation AUC: 0.8931010928961749


Epochs:  75%|███████▌  | 15/20 [00:43<00:14,  2.90s/it]

Validation Accuracy: 0.7706422018348624
Validation F1 Score: 0.7705649574808453
Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.67      0.77        61
           1       0.68      0.90      0.77        48

    accuracy                           0.77       109
   macro avg       0.79      0.78      0.77       109
weighted avg       0.80      0.77      0.77       109

Validation AUC: 0.8828551912568307


Epochs:  80%|████████  | 16/20 [00:46<00:11,  2.86s/it]

Validation Accuracy: 0.7981651376146789
Validation F1 Score: 0.798012129380054
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.74      0.80        61
           1       0.72      0.88      0.79        48

    accuracy                           0.80       109
   macro avg       0.80      0.81      0.80       109
weighted avg       0.81      0.80      0.80       109

Validation AUC: 0.8854166666666666


Epochs:  85%|████████▌ | 17/20 [00:48<00:08,  2.83s/it]

Validation Accuracy: 0.7889908256880734
Validation F1 Score: 0.7887062789717658
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.74      0.80        61
           1       0.72      0.85      0.78        48

    accuracy                           0.79       109
   macro avg       0.79      0.80      0.79       109
weighted avg       0.80      0.79      0.79       109

Validation AUC: 0.8886612021857923


Epochs:  90%|█████████ | 18/20 [00:51<00:05,  2.81s/it]

Validation Accuracy: 0.7889908256880734
Validation F1 Score: 0.7887062789717658
Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.74      0.80        61
           1       0.72      0.85      0.78        48

    accuracy                           0.79       109
   macro avg       0.79      0.80      0.79       109
weighted avg       0.80      0.79      0.79       109

Validation AUC: 0.886441256830601


Epochs:  95%|█████████▌| 19/20 [00:54<00:02,  2.80s/it]

Validation Accuracy: 0.7798165137614679
Validation F1 Score: 0.7796495956873315
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.72      0.79        61
           1       0.71      0.85      0.77        48

    accuracy                           0.78       109
   macro avg       0.78      0.79      0.78       109
weighted avg       0.79      0.78      0.78       109

Validation AUC: 0.8872950819672132


Epochs: 100%|██████████| 20/20 [00:57<00:00,  2.86s/it]

Validation Accuracy: 0.7798165137614679
Validation F1 Score: 0.7796495956873315
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.72      0.79        61
           1       0.71      0.85      0.77        48

    accuracy                           0.78       109
   macro avg       0.78      0.79      0.78       109
weighted avg       0.79      0.78      0.78       109

Validation AUC: 0.8852459016393442





In [13]:
resnet.load_state_dict(torch.load('/content/ResNeXt101_Weights11.pth'))

resnet = resnet.to(device)

torch.save(resnet.state_dict(), '/content/drive/Shareddrives/QB Hacakthon/models path/ResNeXt101_Weights11.pth')