In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from tensorboardX import SummaryWriter

plt.ion()   # interactive mode
%matplotlib inline

In [2]:
torch.cuda.is_available()

True

In [3]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'path to directory'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16 if x=='train' else 64,
                                             shuffle=False if x=='val' else True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [4]:
import torch
import torch.nn as nn
from torchvision.models import vgg19_bn, densenet121

# Load the pre-trained VGG19 model from a .pt file
vgg19_model = models.vgg19_bn(pretrained=True)

# Load the pre-trained DenseNet121 model from a .pt file
densenet121_model = models.densenet121(pretrained=True)

# Define the fusion model by concatenating the output features from the two models
class SumModel(nn.Module):
    def __init__(self, vgg19_model, densenet121_model):
        super(SumModel, self).__init__()
        # self.vgg19_model = vgg19_model.features
        # self.densenet121_model = densenet121_model.features
        self.vgg19_model = vgg19_model
        self.densenet121_model =  densenet121_model
        self.fc1 = nn.Sequential(
        nn.Linear(1000,384),
        nn.ReLU(),
        nn.Dropout(0.25),
        nn.Linear(384,96),
        nn.ReLU(),
        nn.Dropout(0.25),
        nn.Linear(96,7),
    ) 
        # self.conv_layer_out = k[0][34]
        self.fc2 = nn.Sequential(
        nn.Linear(2000,384),
        nn.ReLU(),
        nn.Dropout(0.25),
        nn.Linear(384,96),
        nn.ReLU(),
        nn.Dropout(0.25),
        nn.Linear(96,7),
    )  # Add a linear layer to output 7 classes
        
    def forward(self, x):
        vgg19_output = self.vgg19_model(x)
        # print("VGG19 Output Size",vgg19_output.size())
        # print(vgg19_output)
        densenet121_output = self.densenet121_model(x)
        # print("Densenet121 Output Size",densenet121_output.size())
        # print(densenet121_output)
        fusion_output = torch.cat((vgg19_output, densenet121_output), dim=1)
        # print("Fusion Output Size",fusion_output.size())
        out_vgg = self.fc1(vgg19_output)
        out_densenet = self.fc1(densenet121_output)
        out_fusion = self.fc2(fusion_output)
        # out=fusion_output
        # print("Out Size",out.size())
        return out_fusion, out_vgg, out_densenet

In [5]:
from tqdm import tqdm
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def train_model(model, criterion, optimizer, scheduler, writer=None, num_epochs=30):
    since = time.time()
    
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    fusion_output, vgg, densenet = model(inputs)
                    _, preds = torch.max(fusion_output, 1)
                    
                    beta1=1.0
                    lambda1=0.0
                    lambda2=0.0
                    beta2=1.0
                    lambda3=0.0
                    lambda4=0.0
                    beta3=1.0
                    lambda5=0.0
                    lambda6=0.0
                    
                    l1_regularization_fused = 0
                    l2_regularization_fused = 0
                    for param in model.parameters():
                        l1_regularization_fused += torch.norm(param, p=1)
                        l2_regularization_fused += torch.norm(param, p=2)**2
                    
                    l1_regularization_vgg = 0
                    l2_regularization_vgg = 0
                    for param in model.vgg19_model.parameters():
                        l1_regularization_vgg += torch.norm(param, p=1)
                        l2_regularization_vgg += torch.norm(param, p=2)**2
                    
                    l1_regularization_densenet = 0
                    l2_regularization_densenet = 0
                    for param in model.densenet121_model.parameters():
                        l1_regularization_densenet += torch.norm(param, p=1)
                        l2_regularization_densenet += torch.norm(param, p=2)**2
                    # Compute the fusion loss
                    fusion_loss = beta1 * criterion(fusion_output, labels) + \
                                  lambda1 * l1_regularization_fused + \
                                  lambda2 * l2_regularization_fused + \
                                  beta2 * criterion(vgg, labels) + \
                                  lambda3 * l1_regularization_vgg + \
                                  lambda4 * l2_regularization_vgg + \
                                  beta3 * criterion(densenet, labels) + \
                                  lambda5 * l1_regularization_densenet + \
                                  lambda6 * l2_regularization_densenet
                    
                    # loss = loss1 + loss2 + fusion_loss
                    loss =fusion_loss

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            
            with SummaryWriter() as writer:
                writer.add_scalar(phase+"/loss", epoch_loss, epoch)
                writer.add_scalar(phase+"/acc", epoch_acc, epoch)


            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), '/home/desktop/saved_modell/trial_fnet.pt')

        print()
        

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    
     #Compute confusion matrix
    y_pred = []
    y_true = []
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs[0], 1)
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

#           Convert predicted labels to one-hot encoding
    y_pred_onehot = np.zeros((len(y_pred), len(class_names)))
    y_pred_onehot[np.arange(len(y_pred)), y_pred] = 1

#           Convert ground truth labels to one-hot encoding
    y_true_onehot = np.zeros((len(y_true), len(class_names)))
    y_true_onehot[np.arange(len(y_true)), y_true] = 1

#           Compute confusion matrix
    conf_mat = confusion_matrix(y_true_onehot.argmax(axis=1), y_pred_onehot.argmax(axis=1))
    print(classification_report(y_true_onehot, y_pred_onehot, target_names=class_names))
    print(accuracy_score(y_true_onehot, y_pred_onehot))

    print('Confusion matrix:')
    print(conf_mat)
          
    return model



In [6]:
model = SumModel(vgg19_model, densenet121_model)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.0005)
scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)
model = train_model(model, criterion, optimizer, scheduler, num_epochs=30)



Epoch 0/29
----------


100%|█████████████████████████████████████████| 276/276 [02:07<00:00,  2.16it/s]


train Loss: 4.0045 Acc: 0.5559


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.04it/s]


val Loss: 4.5979 Acc: 0.6823

Epoch 1/29
----------


100%|█████████████████████████████████████████| 276/276 [02:07<00:00,  2.16it/s]


train Loss: 2.5931 Acc: 0.7671


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.02it/s]


val Loss: 2.1404 Acc: 0.7699

Epoch 2/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.14it/s]


train Loss: 2.0391 Acc: 0.8306


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.05it/s]


val Loss: 2.0997 Acc: 0.7943

Epoch 3/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.15it/s]


train Loss: 1.8421 Acc: 0.8594


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.94it/s]


val Loss: 1.7728 Acc: 0.8289

Epoch 4/29
----------


100%|█████████████████████████████████████████| 276/276 [02:10<00:00,  2.11it/s]


train Loss: 1.6150 Acc: 0.8846


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.03it/s]


val Loss: 1.2609 Acc: 0.9185

Epoch 5/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 1.0514 Acc: 0.9472


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.96it/s]


val Loss: 0.7959 Acc: 0.9613

Epoch 6/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.12it/s]


train Loss: 0.9236 Acc: 0.9594


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.09it/s]


val Loss: 0.7626 Acc: 0.9796

Epoch 7/29
----------


100%|█████████████████████████████████████████| 276/276 [02:11<00:00,  2.10it/s]


train Loss: 0.7904 Acc: 0.9714


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.06it/s]


val Loss: 0.7219 Acc: 0.9776

Epoch 8/29
----------


100%|█████████████████████████████████████████| 276/276 [02:11<00:00,  2.10it/s]


train Loss: 0.7312 Acc: 0.9769


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.03it/s]


val Loss: 0.6464 Acc: 0.9817

Epoch 9/29
----------


100%|█████████████████████████████████████████| 276/276 [02:13<00:00,  2.07it/s]


train Loss: 0.6448 Acc: 0.9803


100%|█████████████████████████████████████████████| 8/8 [00:04<00:00,  1.95it/s]


val Loss: 0.6377 Acc: 0.9796

Epoch 10/29
----------


100%|█████████████████████████████████████████| 276/276 [02:12<00:00,  2.08it/s]


train Loss: 0.6049 Acc: 0.9816


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.10it/s]


val Loss: 0.6422 Acc: 0.9796

Epoch 11/29
----------


100%|█████████████████████████████████████████| 276/276 [02:11<00:00,  2.11it/s]


train Loss: 0.5511 Acc: 0.9866


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5794 Acc: 0.9817

Epoch 12/29
----------


100%|█████████████████████████████████████████| 276/276 [02:10<00:00,  2.11it/s]


train Loss: 0.5359 Acc: 0.9875


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5523 Acc: 0.9878

Epoch 13/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.15it/s]


train Loss: 0.4990 Acc: 0.9893


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5414 Acc: 0.9898

Epoch 14/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4985 Acc: 0.9907


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.06it/s]


val Loss: 0.5491 Acc: 0.9837

Epoch 15/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.15it/s]


train Loss: 0.4913 Acc: 0.9907


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.06it/s]


val Loss: 0.5406 Acc: 0.9857

Epoch 16/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.15it/s]


train Loss: 0.4645 Acc: 0.9918


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.08it/s]


val Loss: 0.5334 Acc: 0.9878

Epoch 17/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.16it/s]


train Loss: 0.4640 Acc: 0.9918


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5402 Acc: 0.9857

Epoch 18/29
----------


100%|█████████████████████████████████████████| 276/276 [02:10<00:00,  2.12it/s]


train Loss: 0.4792 Acc: 0.9927


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.05it/s]


val Loss: 0.5288 Acc: 0.9878

Epoch 19/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4662 Acc: 0.9916


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.04it/s]


val Loss: 0.5267 Acc: 0.9857

Epoch 20/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.14it/s]


train Loss: 0.4648 Acc: 0.9925


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.06it/s]


val Loss: 0.5231 Acc: 0.9898

Epoch 21/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4670 Acc: 0.9930


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.10it/s]


val Loss: 0.5559 Acc: 0.9837

Epoch 22/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4619 Acc: 0.9914


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.04it/s]


val Loss: 0.5282 Acc: 0.9857

Epoch 23/29
----------


100%|█████████████████████████████████████████| 276/276 [02:08<00:00,  2.14it/s]


train Loss: 0.4608 Acc: 0.9927


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5278 Acc: 0.9857

Epoch 24/29
----------


100%|█████████████████████████████████████████| 276/276 [02:11<00:00,  2.11it/s]


train Loss: 0.4742 Acc: 0.9925


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.08it/s]


val Loss: 0.5351 Acc: 0.9837

Epoch 25/29
----------


100%|█████████████████████████████████████████| 276/276 [02:07<00:00,  2.16it/s]


train Loss: 0.4715 Acc: 0.9909


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.10it/s]


val Loss: 0.5321 Acc: 0.9837

Epoch 26/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4500 Acc: 0.9943


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.04it/s]


val Loss: 0.5237 Acc: 0.9898

Epoch 27/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4720 Acc: 0.9912


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.03it/s]


val Loss: 0.5328 Acc: 0.9857

Epoch 28/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4724 Acc: 0.9909


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.07it/s]


val Loss: 0.5301 Acc: 0.9898

Epoch 29/29
----------


100%|█████████████████████████████████████████| 276/276 [02:09<00:00,  2.13it/s]


train Loss: 0.4757 Acc: 0.9909


100%|█████████████████████████████████████████████| 8/8 [00:03<00:00,  2.05it/s]

val Loss: 0.5458 Acc: 0.9837

Training complete in 66m 58s
Best val Acc: 0.989817





              precision    recall  f1-score   support

      afraid       1.00      0.94      0.97        70
       angry       1.00      1.00      1.00        70
   disgusted       0.97      1.00      0.99        70
       happy       1.00      1.00      1.00        71
     neutral       1.00      1.00      1.00        70
         sad       1.00      0.99      0.99        70
   surprised       0.96      1.00      0.98        70

   micro avg       0.99      0.99      0.99       491
   macro avg       0.99      0.99      0.99       491
weighted avg       0.99      0.99      0.99       491
 samples avg       0.99      0.99      0.99       491

0.9898167006109979
Confusion matrix:
[[66  0  1  0  0  0  3]
 [ 0 70  0  0  0  0  0]
 [ 0  0 70  0  0  0  0]
 [ 0  0  0 71  0  0  0]
 [ 0  0  0  0 70  0  0]
 [ 0  0  1  0  0 69  0]
 [ 0  0  0  0  0  0 70]]
