In [1]:
#https://github.com/Cadene/pretrained-models.pytorch

import pretrainedmodels
import torch
from Dataloader import load_and_preprocess_dataset
from PIL import Image
from torchvision import transforms
from torchsummary import summary
import numpy as np

In [2]:
train_imgs, train_probs, train_types, test_imgs, test_probs, test_types = load_and_preprocess_dataset(wire_removal="Gray", augment="All", out_types="All", aug_types = ["Flip"], channels=3)

1508 295 106 715
----- Method:[reduce_dataset], ran in 0.0 Seconds,
1508 295 106 715
----- Method:[remove_cell_wires], ran in 2.1291441917419434 Seconds,
----- Method:[split_t_t_data], ran in 4.297459602355957 Seconds,
1131 221 80 535
377 74 26 180
----- Method:[expand_dataset], ran in 6.4475226402282715 Seconds,
----- Method:[expand_dataset], ran in 0.8624765872955322 Seconds,
----- Method:[shuffle_set], ran in 1.108220100402832 Seconds,
4524 884 320 2140
1508 296 104 720
----- Method:[make_3_channel], ran in 48.84747672080994 Seconds,
----- Method:[make_3_channel], ran in 19.061686992645264 Seconds,
----- Method:[load_and_preprocess_dataset], ran in 130.8902130126953 Seconds,


In [3]:
base_model = pretrainedmodels.vgg19()



In [4]:
# fine tuning
dim_feats = base_model.last_linear.in_features
nb_classes = 4
base_model.last_linear = torch.nn.Linear(dim_feats, nb_classes)

In [5]:
for param in base_model._features:
    param.requires_grad = False

In [6]:
base_model.linear0.requires_grad = True
base_model.relu0.requires_grad = True
base_model.dropout0.requires_grad = True
base_model.linear1.requires_grad = True
base_model.relu1.requires_grad = True
base_model.dropout1.requires_grad = True
base_model.last_linear.requires_grad = True

In [7]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(base_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

In [8]:
class AllDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        image = (image * 255).astype(np.uint8)
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        # image = torch.tensor(image, dtype=torch.float32)

        return image, label
    
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
])

In [9]:
dataset = AllDataset(train_imgs, train_probs, transform)
validation_split = 0.2
dataset_size = len(dataset)
val_size = int(validation_split * dataset_size)
train_size = dataset_size - val_size

# Use random_split to create training and validation datasets
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)

In [10]:
epochs = 5

for epoch in range(epochs):
    # Training phase
    base_model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch_images, batch_labels in train_dataloader:
        
        batch_labels = (batch_labels * (4 - 1)).round().long()
        one_hot_labels = torch.zeros(len(batch_labels), 4)
        one_hot_labels.scatter_(1, batch_labels.view(-1, 1), 1)

        optimizer.zero_grad()
        outputs = base_model(batch_images)
        loss = criterion(outputs, one_hot_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(f'Batch Loss: {loss.item():.4f}')
        _, predicted = torch.max(outputs, 1)
        total_samples += one_hot_labels.size(0)
        correct_predictions += (predicted == batch_labels).sum().item()

    average_loss = running_loss / len(train_dataloader)
    accuracy = correct_predictions / total_samples
    print(f'Training Epoch [{epoch + 1}/{epochs}], Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}')

    # Validation phase
    base_model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for val_images, val_labels in val_dataloader:
            
            val_outputs = base_model(val_images)
            
            val_labels = (val_labels * (4 - 1)).round().long()
            one_hot_labels = torch.zeros(len(val_labels), 4)
            one_hot_labels.scatter_(1, val_labels.view(-1, 1), 1)
            val_loss = criterion(val_outputs, one_hot_labels)

            val_running_loss += val_loss.item()

            _, predicted = torch.max(val_outputs, 1)
            total_samples += val_labels.size(0)
            correct_predictions += (predicted == val_labels).sum().item()
            

    val_average_loss = val_running_loss / len(val_dataloader)
    val_accuracy = correct_predictions / total_samples

    print(f'Validation Epoch [{epoch + 1}/{epochs}], Validation Loss: {average_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

Training Epoch [1/5], Loss: 0.4692, Accuracy: 0.5692
Validation Epoch [1/5], Validation Loss: 0.4692, Validation Accuracy: 0.5779
Training Epoch [2/5], Loss: 0.4505, Accuracy: 0.5728
Validation Epoch [2/5], Validation Loss: 0.4505, Validation Accuracy: 0.5779
Training Epoch [3/5], Loss: 0.4497, Accuracy: 0.5743
Validation Epoch [3/5], Validation Loss: 0.4497, Validation Accuracy: 0.5779
Training Epoch [4/5], Loss: 0.4489, Accuracy: 0.5743
Validation Epoch [4/5], Validation Loss: 0.4489, Validation Accuracy: 0.5779
Training Epoch [5/5], Loss: 0.4498, Accuracy: 0.5743
Validation Epoch [5/5], Validation Loss: 0.4498, Validation Accuracy: 0.5779


In [11]:
test_dataset = AllDataset(test_imgs, test_probs, transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

base_model.eval()

# Initialize variables for accuracy calculation
running_loss = 0.0
correct_predictions = 0
total_samples = 0
all_predictions = torch.Tensor([])
all_labels = torch.Tensor([])

# Disable gradient calculation during testing
with torch.no_grad():

    for test_images, test_labels in test_dataloader:

        outputs = base_model(test_images)
        
        test_labels = (test_labels * (4 - 1)).round().long()
        one_hot_labels = torch.zeros(len(test_labels), 4)
        one_hot_labels.scatter_(1, test_labels.view(-1, 1), 1)

        loss = criterion(outputs, one_hot_labels)

        running_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total_samples += test_labels.size(0)
        correct_predictions += (predicted == test_labels).sum().item()
        all_predictions = torch.cat((all_predictions, predicted))
        all_labels = torch.cat((all_labels, test_labels))

    average_loss = running_loss / len(test_dataloader)
    accuracy = correct_predictions / total_samples

# Calculate overall accuracy
accuracy = correct_predictions / total_samples

print(average_loss)
print(accuracy)

0.4498161711476066
0.573820395738204


In [12]:
from sklearn.metrics import classification_report

print(classification_report(all_predictions, all_labels))

              precision    recall  f1-score   support

         0.0       1.00      0.57      0.73      2628
         1.0       0.00      0.00      0.00         0
         2.0       0.00      0.00      0.00         0
         3.0       0.00      0.00      0.00         0

    accuracy                           0.57      2628
   macro avg       0.25      0.14      0.18      2628
weighted avg       1.00      0.57      0.73      2628



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [13]:
torch.save(base_model.state_dict(), "models/vgg19-bw.pth")

In [15]:
torch.load('models/vgg19-bw.pth')

OrderedDict([('_features.0.weight',
              tensor([[[[-0.0583, -0.0541, -0.0727],
                        [ 0.0105,  0.0403, -0.0027],
                        [ 0.0314,  0.0152,  0.0151]],
              
                       [[ 0.0122,  0.0506, -0.0110],
                        [ 0.1368,  0.2222,  0.1328],
                        [ 0.1152,  0.1955,  0.0873]],
              
                       [[-0.0497,  0.0079, -0.0193],
                        [ 0.0549,  0.1347,  0.0493],
                        [-0.0058,  0.0535, -0.0345]]],
              
              
                      [[[ 0.2648, -0.3006, -0.4968],
                        [ 0.4178, -0.2041, -0.4861],
                        [ 0.5914,  0.4324, -0.1341]],
              
                       [[ 0.2915, -0.3291, -0.4509],
                        [ 0.3824, -0.2872, -0.4925],
                        [ 0.5515,  0.4937, -0.1673]],
              
                       [[ 0.0713, -0.0911, -0.0333],
                    