In [105]:
#https://github.com/Cadene/pretrained-models.pytorch

import pretrainedmodels
import torch
from Dataloader import load_and_preprocess_dataset
from PIL import Image
from torchvision import transforms
from torchsummary import summary
import numpy as np

In [2]:
train_imgs, train_probs, train_types, test_imgs, test_probs, test_types = load_and_preprocess_dataset(wire_removal="Gray", augment="All", out_types="Mono", aug_types = ["Flip", "Rot"], channels=3)

1508 295 106 715
----- Method:[reduce_dataset], ran in 0.16181230545043945 Seconds,
588 117 56 313
----- Method:[remove_cell_wires], ran in 0.4165050983428955 Seconds,
----- Method:[split_t_t_data], ran in 0.3147096633911133 Seconds,
441 88 42 234
147 29 14 79
----- Method:[expand_dataset], ran in 0.5301566123962402 Seconds,
----- Method:[expand_dataset], ran in 0.16232037544250488 Seconds,
----- Method:[shuffle_set], ran in 0.3378615379333496 Seconds,
1764 352 168 936
588 116 56 316
----- Method:[make_3_channel], ran in 2.581594705581665 Seconds,
----- Method:[make_3_channel], ran in 1.7840611934661865 Seconds,
----- Method:[load_and_preprocess_dataset], ran in 29.090898513793945 Seconds,


In [171]:
base_model = pretrainedmodels.vgg19()



VGG(
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (_features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv

In [172]:
for param in base_model._features:
    param.requires_grad = False

In [173]:
base_model.linear0.requires_grad = True
base_model.relu0.requires_grad = True
base_model.dropout0.requires_grad = True
base_model.linear1.requires_grad = True
base_model.relu1.requires_grad = True
base_model.dropout1.requires_grad = True
base_model.last_linear.requires_grad = True

In [174]:
# fine tuning
dim_feats = base_model.last_linear.in_features
nb_classes = 4
base_model.last_linear = torch.nn.Linear(dim_feats, nb_classes)

In [175]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(base_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

In [176]:
class AllDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        # image = torch.tensor(image, dtype=torch.float32)

        return image, label

In [177]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
])

dataset = AllDataset(train_imgs, train_probs, transform)
validation_split = 0.2
dataset_size = len(dataset)
val_size = int(validation_split * dataset_size)
train_size = dataset_size - val_size

# Use random_split to create training and validation datasets
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)

In [193]:
epochs = 2

for epoch in range(epochs):
    # Training phase
    base_model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch_images, batch_labels in train_dataloader:
        
        batch_labels = (batch_labels * (4 - 1)).round().long()
        one_hot_labels = torch.zeros(len(batch_labels), 4)
        one_hot_labels.scatter_(1, batch_labels.view(-1, 1), 1)

        optimizer.zero_grad()
        outputs = base_model(batch_images)
        loss = criterion(outputs, one_hot_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(f'Batch Loss: {loss.item():.4f}')
        _, predicted = torch.max(outputs, 1)
        total_samples += one_hot_labels.size(0)
        correct_predictions += (predicted == batch_labels).sum().item()

    average_loss = running_loss / len(train_dataloader)
    accuracy = correct_predictions / total_samples
    print(f'Training Epoch [{epoch + 1}/{epochs}], Loss: {average_loss:.4f}')

    # Validation phase
    base_model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for val_images, val_labels in val_dataloader:
            
            val_outputs = base_model(val_images)
            
            val_labels = (val_labels * (4 - 1)).round().long()
            one_hot_labels = torch.zeros(len(val_labels), 4)
            one_hot_labels.scatter_(1, val_labels.view(-1, 1), 1)
            val_loss = criterion(val_outputs, one_hot_labels)

            val_running_loss += val_loss.item()

            _, predicted = torch.max(val_outputs, 1)
            total_samples += val_labels.size(0)
            correct_predictions += (predicted == val_labels).sum().item()

    val_average_loss = val_running_loss / len(val_dataloader)
    val_accuracy = correct_predictions / total_samples

    print(f'Validation Epoch [{epoch + 1}/{epochs}], Validation Loss: {average_loss:.4f}')

In [191]:
test_dataset = AllDataset(test_imgs, test_probs, transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

base_model.eval()

# Initialize variables for accuracy calculation
running_loss = 0.0
correct_predictions = 0
total_samples = 0
all_predictions = torch.Tensor([])
all_labels = torch.Tensor([])

# Disable gradient calculation during testing
with torch.no_grad():

    for test_images, test_labels in test_dataloader:

        outputs = base_model(test_images)
        
        test_labels = (test_labels * (4 - 1)).round().long()
        one_hot_labels = torch.zeros(len(test_labels), 4)
        one_hot_labels.scatter_(1, test_labels.view(-1, 1), 1)

        loss = criterion(outputs, one_hot_labels)

        running_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total_samples += test_labels.size(0)
        correct_predictions += (predicted == test_labels).sum().item()
        all_predictions = torch.cat((all_predictions, predicted))
        all_labels = torch.cat((all_labels, test_labels))

    average_loss = running_loss / len(test_dataloader)
    accuracy = correct_predictions / total_samples

# Calculate overall accuracy
accuracy = correct_predictions / total_samples

print(accuracy)

0.5464684014869888


In [192]:
from sklearn.metrics import classification_report

print(classification_report(all_predictions, all_labels))

              precision    recall  f1-score   support

         0.0       1.00      0.55      0.71      1076
         1.0       0.00      0.00      0.00         0
         2.0       0.00      0.00      0.00         0
         3.0       0.00      0.00      0.00         0

    accuracy                           0.55      1076
   macro avg       0.25      0.14      0.18      1076
weighted avg       1.00      0.55      0.71      1076



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [13]:
torch.save(base_model, '../models/vgg19-bw.pth')