Train and evaluate a model with the Semantically Enhanced Feature (SEF) method, using a ResNet backbone.


The training dataset is a subset of Stanford Dogs covering the 25 breeds we were able to capture with the Wyze Cam. The `train` folder holds the training data; `wyze`, `yt`, and `google` are the test sets. The data is available at https://drive.google.com/drive/folders/1GbegJxFDZHp0NiN0bq9VngtMXo9Vjaoi?usp=sharing

You will need to update the paths in the code below to match your own Google Drive layout.

In [None]:
# Mount Google Drive in this Colab session; the dataset, pretrained weights
# and output folders used below all live under /content/drive/MyDrive/Capstone.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/MyDrive/Capstone/SEF-master/

In [None]:
!python --version

In [None]:
import os
import sys
import torch
print(torch.__version__)
import pandas as pd
import numpy as np
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, random_split, DataLoader
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from collections import OrderedDict
import statistics
import torch.utils.model_zoo as model_zoo
# from .utils import load_state_dict_from_url
import pdb
import torch.nn.functional as torchf
# from utils.misc import SoftSigmoid
import pickle as pk
import uuid
import argparse
import torch.optim as opt
from torch.optim import lr_scheduler
import torch.multiprocessing as mlp
import torch.utils.tensorboard as tb
import copy
import time

In [None]:
%load_ext tensorboard

In [None]:
import tensorflow as tf
import datetime

In [None]:
# torch.cuda.empty_cache()

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# (is_available() already returns a bool; comparing it with 0 was redundant.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Human-readable device label for the run log, e.g. 'cuda:0' or 'cpu'.
device_name = f"{device.type}:{device.index}" if device.type == 'cuda' else 'cpu'

In [None]:
import re

# Make the SEF project importable.
progpath = '/content/drive/MyDrive/Capstone/SEF-master'
sys.path.append(progpath)

# ---- Experiment hyper-parameters ----
datasetname = 'stdogs25'
image_size = 448
batchsize = 16
nthreads = 4
lr = 0.001
lmgm = 1        # weight (rho) of the LocalMaxGlobalMin semantic-grouping loss
entropy = 1     # weight of the entropy penalty on the global prediction
soft = 0.05     # weight of the soft-target loss between global and part predictions
epochs = 50
optmeth = 'sgd'
regmeth = 'cms'

# Number of attentions (semantic parts); nparts = 2 is recommended for Stanford Dogs.
nparts = 2

# 'resnet50attention' for SEF, 'resnet50maxent' for ResNet with MaxEnt,
# 'resnet50vanilla' for the vanilla ResNet
networkname = 'resnet18attention'
if 'attention' in networkname:      # SEF based on ResNet
    attention_flag = True
    maxent_flag = False
else:                               # plain ResNet, optionally with MaxEnt
    lmgm = entropy = soft = 0
    nparts = 1
    attention_flag = False
    # '...maxent' networks enable the maximum-entropy penalty in train().
    maxent_flag = 'maxent' in networkname

# ---- Logging ----
timeflag = time.strftime("%d-%b-%Y-%H:%M")
# Network depth parsed from the name ('resnet18attention' -> 18,
# 'resnet101vanilla' -> 101); a fixed two-character slice truncates 3-digit depths.
net_depth = int(re.search(r'\d+', networkname).group())
log_items = r'{}-net{}-att{}-lmgm{}-entropy{}-soft{}-lr{}-imgsz{}-bsz{}'.format(
    datasetname, net_depth, nparts, lmgm, entropy, soft, lr, image_size, batchsize)
writer = tb.SummaryWriter(comment='-'+log_items)
# Make sure the text-log folder exists before opening the log file in it.
os.makedirs('./runs', exist_ok=True)
logfile = open('./runs/'+log_items+'.txt', 'w')

# ---- Model name ----
modelname = log_items + f'-{epochs}epoch' + '.model'
# modelname = log_items + f'-{epochs}epoch' + '-aug5' + '.model'
# modelname = log_items + '-aug' + '.model'

# ---- Paths: dataset root, pretrained-weights folder and output folders ----
datapath = '/content/drive/MyDrive/Capstone/' # path to dataset
modelzoopath = '/content/drive/MyDrive/Capstone/SEF-master/' # path to SEF master folder
sys.path.append(modelzoopath)
datasetpath = '/content/drive/MyDrive/Capstone/stdogs25'
modelpath = '/content/drive/MyDrive/Capstone/SEF-master/models'
resultpath = '/content/drive/MyDrive/Capstone/SEF-master/runs'

In [None]:
# Transforms for different subsets of stdogs25
# wyze, google and yt are our test sets
# Per-split preprocessing for stdogs25.
#   train: resize to 600x600, random 448x448 crop, resize to image_size,
#          random horizontal flip.
#   wyze / yt / google (our test sets): plain resize to image_size.
# Every split is normalized with the same per-channel mean/std.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]


def _make_eval_transform():
    """Return a fresh resize + tensor + normalize pipeline for a test split."""
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])


data_transform = {
    'train': transforms.Compose([
        transforms.Resize((600, 600)),
        transforms.RandomCrop((448, 448)),
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
}
data_transform.update({split: _make_eval_transform() for split in ('wyze', 'yt', 'google')})

In [None]:
#################################################################################
###### Select the test set for training, can reselect for evaluation later ######
#################################################################################
# Test split used for model selection during training (it can be changed again
# before evaluation). Options: 'wyze', 'google', 'yt'.
testset = 'wyze'
# testset = 'google'
# testset = 'yt'

# ImageFolder dataset and shuffled DataLoader for the training split and the
# chosen test split.
datasplits = dict()
dataloader = dict()
for subset in ('train', testset):
    datasplits[subset] = datasets.ImageFolder(os.path.join(datasetpath, subset), data_transform[subset])
    dataloader[subset] = torch.utils.data.DataLoader(
        datasplits[subset],
        batch_size=batchsize,
        shuffle=True,
        num_workers=nthreads,
    )

class_names = datasplits['train'].classes
num_classes = len(class_names)
print(len(dataloader))

In [None]:
# Machine epsilon of torch's default floating dtype; added to feature norms
# so normalizing an all-zero channel never divides by zero.
eps = torch.finfo().eps

# Public names of this module.
__all__ = [
    'ResNet',
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d',
    'wide_resnet50_2', 'wide_resnet101_2',
]


# ImageNet checkpoint URL for each architecture. Download the .pth file for the
# model you are going to be training and place it in the SEF master folder.
_TORCH_MODEL_ROOT = 'https://download.pytorch.org/models/'
model_urls = {
    arch: _TORCH_MODEL_ROOT + checkpoint
    for arch, checkpoint in (
        ('resnet18', 'resnet18-5c106cde.pth'),
        ('resnet34', 'resnet34-333f7ec4.pth'),
        ('resnet50', 'resnet50-19c8e357.pth'),
        ('resnet101', 'resnet101-5d3b4d8f.pth'),
        ('resnet152', 'resnet152-b121ed2d.pth'),
        ('resnext50_32x4d', 'resnext50_32x4d-7cdf4587.pth'),
        ('resnext101_32x8d', 'resnext101_32x8d-8ba56ff5.pth'),
        ('wide_resnet50_2', 'wide_resnet50_2-95faca4d.pth'),
        ('wide_resnet101_2', 'wide_resnet101_2-32ee1156.pth'),
    )
}

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding equals ``dilation``, so with ``stride=1`` the spatial size of the
    input is preserved.
    """
    options = dict(kernel_size=3, stride=stride, padding=dilation,
                   groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (channel projection, optional stride)."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer

## Model architecture

In [None]:
class LocalMaxGlobalMin(nn.Module):
    """Semantic-grouping regularizer on channel-similarity matrices.

    The ``nchannels`` channels are split into ``nparts`` contiguous groups
    (the last group absorbs any remainder). For squared similarities ``S`` the
    loss is::

        rho / 2 * ( mean_g mean(1 - S[g, g])
                    + sum_{g > 0} mean(S[g, :g]) / (nparts * (nparts - 1) / 2) )

    Similarities inside a group are pushed towards 1, and similarities between
    a group and all groups before it are pushed towards 0. With a single group
    only the first term exists.

    Args:
        rho: loss weight.
        nchannels: number of channels; the last two dims of the input must
            both equal it.
        nparts: number of channel groups (>= 1).
        device: stored as an attribute; not used by this module.
    """

    def __init__(self, rho, nchannels, nparts=1, device='cpu'):
        super(LocalMaxGlobalMin, self).__init__()
        self.nparts = nparts
        self.device = device
        self.nchannels = nchannels
        self.rho = rho

        group_size = nchannels // self.nparts
        remainder = nchannels % self.nparts
        # seps[i] is the exclusive end channel of group i; seps[-1] == nchannels.
        seps = []
        end = 0
        for i in range(self.nparts):
            end += group_size + (remainder if i == self.nparts - 1 else 0)
            seps.append(end)
        self.seps = seps

    def forward(self, x):
        """Return the scalar regularization loss for similarity tensor ``x``.

        ``x`` is indexed as ``x[:, rows, cols]``: a batch of square
        channel-by-channel similarity matrices of side ``nchannels``.

        Raises:
            ValueError: if the last two dims of ``x`` do not equal
                ``nchannels``. Otherwise the group slices would be empty and
                their means NaN.
        """
        nch = self.seps[-1]
        if x.size(-1) != nch or x.size(-2) != nch:
            raise ValueError(
                'LocalMaxGlobalMin expects {0}x{0} similarity matrices, got {1}x{2}'.format(
                    nch, x.size(-2), x.size(-1)))

        x = x.pow(2)
        intra_x = []
        inter_x = []
        start = 0
        for stop in self.seps:
            # Within-group block: similarities should approach 1.
            intra_x.append((1 - x[:, start:stop, start:stop]).mean())
            if start > 0:
                # This group's rows vs. all earlier groups: should approach 0.
                inter_x.append(x[:, start:stop, :start].mean())
            start = stop

        loss = sum(intra_x) / self.nparts
        if self.nparts > 1:
            # The pair count nparts*(nparts-1)/2 is zero for a single group.
            loss = loss + sum(inter_x) / (self.nparts * (self.nparts - 1) / 2)
        return self.rho * 0.5 * loss
        

In [None]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + norm layers (output width = ``planes``).

    Args:
        inplanes: input channels.
        planes: output channels (``expansion`` is 1).
        stride: stride of the first conv. When it changes the shape, the caller
            must pass a matching ``downsample`` module for the shortcut.
        downsample: optional module applied to the input on the shortcut path.
        groups, base_width: only 1 and 64 are supported.
        dilation: only 1 is supported.
        norm_layer: normalization factory, ``nn.BatchNorm2d`` by default.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if (groups, base_width) != (1, 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Attribute names and creation order are kept fixed: pretrained
        # state_dict keys are matched by these names. With stride != 1 both
        # conv1 and downsample reduce the spatial size.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block 1x1 -> 3x3 -> 1x1 with ``expansion`` = 4 output widening.

    Args:
        inplanes: input channels.
        planes: base width; the block outputs ``planes * 4`` channels.
        stride: stride of the 3x3 conv. When it changes the shape, the caller
            must pass a matching ``downsample`` module.
        downsample: optional module applied to the input on the shortcut path.
        groups, base_width: grouped/wide settings; the inner width is
            ``int(planes * base_width / 64) * groups``.
        dilation: dilation of the 3x3 conv.
        norm_layer: normalization factory, ``nn.BatchNorm2d`` by default.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # Attribute names and creation order are kept fixed so pretrained
        # state_dict keys line up. With stride != 1 both conv2 and downsample
        # reduce the spatial size.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn3(conv3(...)) + shortcut(x))``."""
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """ResNet backbone with an optional SEF (semantically enhanced feature) head.

    With ``attention=False`` this is a standard ResNet classifier and
    ``forward`` returns class logits. With ``attention=True`` the channels of
    the last stage are split into ``nparts`` contiguous groups. Each group gets
    its own linear classifier (``fclocal``), and ``forward`` returns the tuple
    ``(xglobal, xlocal, xcosin, xmaps)``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: number of classifier outputs.
        nparts: number of channel groups. It is used as a divisor, so it must
            be >= 1 when ``attention`` is True.
        zero_init_residual: zero-init the last norm layer of every residual branch.
        groups, width_per_group: grouped-convolution settings passed to blocks.
        replace_stride_with_dilation: three flags for stages 2-4.
        norm_layer: normalization factory (default ``nn.BatchNorm2d``).
        attention: enable the SEF head.
        device: stored on the module; not used by the code in this class.
    """

    def __init__(self, block, layers, num_classes=1000, nparts=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, attention=False, device='cpu'):
        super(ResNet, self).__init__()

        self.attention = attention
        self.device = device
        self.nparts = nparts

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the resolution (or dilate).
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        if self.attention:
            # Split the 512*expansion output channels into nparts contiguous
            # groups; the last group also takes the remainder channels.
            nfeatures = 512 * block.expansion
            nlocal_channels_norm = nfeatures // self.nparts
            reminder = nfeatures % self.nparts
            nlocal_channels_last = nlocal_channels_norm
            if reminder != 0:
                nlocal_channels_last = nlocal_channels_norm + reminder
            fc_list = []
            separations = []
            sep_node = 0
            for i in range(self.nparts):
                if i != self.nparts-1:
                    sep_node += nlocal_channels_norm
                    fc_list.append(nn.Linear(nlocal_channels_norm, num_classes))
                    #separations.append(sep_node)
                else:
                    sep_node += nlocal_channels_last
                    fc_list.append(nn.Linear(nlocal_channels_last, num_classes))
                # separations[i] is the exclusive end channel of group i.
                separations.append(sep_node)
            # One linear classifier per channel group. nn.Sequential serves only
            # as an indexable container: forward() indexes it, never calls it.
            self.fclocal = nn.Sequential(*fc_list)
            self.separations = separations
            # Global classifier over all pooled channels.
            self.fc = nn.Linear(512*block.expansion, num_classes)

        else:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change stride/width; it gets a 1x1 conv + norm
        projection shortcut when needed. Updates ``self.inplanes`` (and
        ``self.dilation`` when ``dilate`` is True) for the next stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation to keep the spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch of images (conv1 expects 3 input channels).

        Returns logits ``(N, num_classes)`` when ``attention`` is False.
        Otherwise returns ``(xglobal, xlocal, xcosin, xmaps)``:
          xglobal: (N, num_classes) logits from all pooled channels.
          xlocal:  (nparts, N, num_classes) logits of each channel group.
          xcosin:  (N, C, C) cosine similarity between the flattened channel
                   maps of the last stage.
          xmaps:   detached copy of the last-stage feature maps (N, C, H, W).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.attention:

            nsamples, nchannels, height, width = x.shape

            # Flatten each channel map, L2-normalize it (eps avoids 0/0), and
            # take all pairwise dot products -> channel cosine similarities.
            xview = x.view(nsamples, nchannels, -1)
            xnorm = xview.div(xview.norm(dim=-1, keepdim=True)+eps)
            xcosin = torch.bmm(xnorm, xnorm.transpose(-1, -2))


            # Per-group logits: pool the group's channels, then apply its own classifier.
            attention_scores = []
            for i in range(self.nparts):
                if i == 0:
                    xx = x[:, :self.separations[i]]
                else:
                    xx = x[:, self.separations[i-1]:self.separations[i]]
                xx_pool = self.avgpool(xx).flatten(1)
                attention_scores.append(self.fclocal[i](xx_pool))
            xlocal = torch.stack(attention_scores, dim=0)

            # Feature maps returned without gradient tracking (e.g. for visualization).
            xmaps = x.clone().detach()

            # for global
            xpool = self.avgpool(x)
            xpool = torch.flatten(xpool, 1)
            xglobal = self.fc(xpool)


            return xglobal, xlocal, xcosin, xmaps
        else:
            # for original resnet outputs
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

            return x

In [None]:
def _resnet(arch, block, layers, pretrained, progress, model_dir=None, **kwargs):
    """Build a ``ResNet`` and optionally load the ImageNet weights for ``arch``.

    Args:
        arch: key into ``model_urls`` (e.g. 'resnet18').
        block: residual block class.
        layers: blocks per stage.
        pretrained: if True, download (or reuse from ``model_dir``) and load
            the checkpoint. The load is strict, so the model's layers, including
            the classifier size, must match the checkpoint.
        progress: show a download progress bar.
        model_dir: cache directory for downloaded checkpoints (None = torch hub default).
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed model.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        # Fetch and load the checkpoint once, honouring both model_dir and progress.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[arch], model_dir=model_dir, progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, model_dir=None, **kwargs):
    r"""ResNet-18 (``BasicBlock`` x [2, 2, 2, 2]).

    See `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        model_dir (str, optional): download/cache directory for the pretrained weights
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``, ``nparts``, ``attention``)
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage,
                   pretrained, progress, model_dir, **kwargs)


def resnet34(pretrained=False, progress=True, model_dir=None, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        model_dir (str, optional): download/cache directory for the pretrained
            weights. Accepted here as in resnet18/resnet50, so it is no longer
            swallowed by ``**kwargs`` and passed on to ``ResNet``.
        **kwargs: forwarded to ``ResNet``
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, model_dir,
                   **kwargs)


def resnet50(pretrained=False, progress=True, model_dir=None, **kwargs):
    r"""ResNet-50 (``Bottleneck`` x [3, 4, 6, 3]).

    See `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        model_dir (str, optional): download/cache directory for the pretrained weights
        **kwargs: forwarded to ``ResNet``
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, blocks_per_stage,
                   pretrained, progress, model_dir, **kwargs)

## Define your model

In [None]:
# Build the network with random weights, then copy in the ImageNet weights from
# the local .pth file in the SEF master folder (no download needed).
model = resnet18(pretrained=False, model_dir=modelzoopath, nparts=nparts, num_classes=num_classes, attention=attention_flag, device=device)
state_dict_path = os.path.join(modelzoopath, 'resnet18-5c106cde.pth')

# map_location='cpu' lets the checkpoint load on runtimes without a GPU;
# the model is moved to `device` at the end of this cell.
state_params = torch.load(state_dict_path, map_location='cpu')

# Drop the ImageNet classifier from the loaded states; its output size does not
# match num_classes.
state_params.pop('fc.weight')
state_params.pop('fc.bias')

# Fresh output layer sized for this dataset.
in_channels = model.fc.in_features
new_fc = nn.Linear(in_channels, num_classes, bias=True)
model.fc = new_fc

# Initialize the model from the pretrained params, except the new layers
# (strict=False). Report what was skipped so a checkpoint for the wrong
# architecture does not go unnoticed.
load_result = model.load_state_dict(state_params, strict=False)
print('Missing keys (kept at initial values):', load_result.missing_keys)
print('Unexpected keys (ignored):', load_result.unexpected_keys)

# TensorBoard: log one batch of test images and the model graph.
images, _ = next(iter(dataloader[testset]))
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid)
writer.add_graph(model, images)

# to gpu if available
model.to(device)

In [None]:
cls_loss = nn.CrossEntropyLoss()

# Semantic group loss. nchannels must equal the channel count of the final
# feature maps (512 * block.expansion: 512 for resnet18/34, 2048 for resnet50),
# which is exactly the input width of the classifier. A hard-coded 512*4 on
# resnet18 makes the group slices empty, and the loss becomes NaN.
lmgm_loss = LocalMaxGlobalMin(rho=lmgm, nchannels=model.fc.in_features, nparts=nparts, device=device)

criterion = [cls_loss, lmgm_loss]

optimizer = opt.SGD(model.parameters(), lr=lr, momentum=0.9)

# Optimization scheduler: divide the learning rate by 10 every 20 epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [None]:
# Probability / log-probability over the last (class) axis, and a KL-divergence
# criterion averaged over the batch.
softmax, logsoftmax = nn.Softmax(dim=-1), nn.LogSoftmax(dim=-1)
kldiv = nn.KLDivLoss(reduction='batchmean')

In [None]:
from torchsummary import summary

# Layer-by-layer summary for one input at the training resolution. torchsummary
# defaults to device='cuda' and accepts only 'cuda' or 'cpu', so pass the
# device in use to keep this cell working on CPU-only runtimes.
summary(model, (3, image_size, image_size), device=device.type)

## Train function

In [None]:
def train(model, dataloader, criterion, optimizer, scheduler, datasetname=None, isckpt=False, epochs=5, networkname=None, writer=None, maxent_flag=False, device='cpu', **penalty):
    """Train ``model`` and keep the weights that score best on the evaluation split.

    Each epoch makes one pass over ``dataloader['train']``, then one evaluation
    pass over every other split in ``dataloader`` (normally just the test
    set), then calls ``scheduler.step()``.

    Args:
        model: network exposing ``nparts`` and ``attention`` attributes. If
            ``model.attention`` is true, its forward must return
            ``(xglobal, xlocal, xcosin, xmaps)`` with ``xlocal[i]`` the logits
            of part ``i``; otherwise it returns logits.
        dataloader: dict mapping split name -> DataLoader; must contain 'train'.
        criterion: classification loss, or list ``[cls_loss, reg_loss]``. The
            second entry is only called (on ``xcosin``) for attention models.
        optimizer: optimizer over the model parameters.
        scheduler: LR scheduler, stepped once per epoch.
        datasetname: recorded in the returned params.
        isckpt, networkname: currently unused (checkpointing is disabled).
        epochs: number of epochs.
        writer: optional TensorBoard SummaryWriter.
        maxent_flag: for non-attention models, add a maximum-entropy penalty.
        device: device the batches are moved to.
        **penalty: must contain ``logfile`` (text stream), ``soft_weights`` and
            ``entropy_weights``.

    Returns:
        ``(model, rsltparams)``: the model loaded with its best weights, and a
        dict of results and hyper-parameters.

    Raises:
        TypeError: if ``dataloader`` is not a dict.
    """
    output_log_file = penalty['logfile']
    nparts = model.nparts
    attention_flag = model.attention

    if not isinstance(dataloader, dict):
        raise TypeError('dataloader must be a dict mapping split names to DataLoaders')
    dataset_sizes = {x: len(dataloader[x].dataset) for x in dataloader.keys()}
    print(dataset_sizes)
    # 'train' first, then every evaluation split (no dependence on a global testset).
    phases = ['train'] + [x for x in dataloader if x != 'train']

    if not isinstance(criterion, list):
        criterion = [criterion]

    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    global_step = 0
    global_step_resume = 0
    best_epoch = 0
    best_step = 0
    start_epoch = -1  # checkpoint resuming is disabled, so training starts at epoch 0

    since = time.time()
    for epoch in range(start_epoch + 1, epochs):

        # print to log file, then to terminal (file=None means stdout)
        for stream in (output_log_file, None):
            print('Epoch {}/{}'.format(epoch, epochs), file=stream)
            print('-' * 10, file=stream)

        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
                global_step = global_step_resume
            else:
                model.eval()   # Set model to evaluate mode
                global_step_resume = global_step

            running_cls_loss = 0.0
            running_reg_loss = 0.0
            running_corrects = 0
            running_corrects_parts = [0] * nparts
            epoch_acc_parts = [0.0] * nparts

            for inputs, labels in dataloader[phase]:
                # .to() works on CPU-only hosts as well (.cuda() did not).
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):

                    if attention_flag:
                        # outputs are logits from linear models
                        xglobal, xlocal, xcosin, xmaps = model(inputs)
                        probs = torch.softmax(xglobal, dim=-1)
                        cls_loss = criterion[0](xglobal, labels)

                        # Prediction and log-probabilities of every part branch.
                        # xlocal[i] is used without squeeze(), so a batch of a
                        # single sample keeps its batch dimension.
                        predl, logprobl = [], []
                        for i in range(nparts):
                            predl.append(torch.max(torch.softmax(xlocal[i], dim=-1), 1)[-1])
                            logprobl.append(torch.log_softmax(xlocal[i], dim=-1))

                        # Entropy penalty: sum(p * log p) = -H(p), so minimizing
                        # it raises the entropy of the global prediction.
                        logprobs = torch.log_softmax(xglobal, dim=-1)
                        entropy_loss = penalty['entropy_weights'] * torch.mul(probs, logprobs).sum().div(inputs.size(0))
                        # Soft-target loss: cross-entropy of each part's
                        # prediction against the global probabilities.
                        soft_loss_list = [torch.mul(torch.neg(probs), logprobl[i]).sum().div(inputs.size(0))
                                          for i in range(nparts)]
                        soft_loss = penalty['soft_weights'] * sum(soft_loss_list).div(nparts)

                        # regularization loss
                        lmgm_reg_loss = criterion[1](xcosin)
                        reg_loss = lmgm_reg_loss + entropy_loss + soft_loss

                    else:
                        outputs = model(inputs)
                        probs = torch.softmax(outputs, dim=-1)
                        cls_loss = criterion[0](outputs, labels)
                        if maxent_flag:
                            # Maximum-entropy penalty with the same sign as
                            # entropy_loss above: sum(p * log p) / N = -H(p),
                            # so minimizing it maximizes the entropy.
                            logprobs = torch.log_softmax(outputs, dim=-1)
                            reg_loss = torch.mul(probs, logprobs).sum().div(inputs.size(0))
                        else:
                            reg_loss = torch.tensor(0.0)

                    # Index of the largest probability in each row.
                    _, preds = torch.max(probs, 1)

                    all_loss = cls_loss + reg_loss

                    if is_train:
                        all_loss.backward()
                        optimizer.step()

                # statistics
                batch_size = inputs.size(0)
                running_cls_loss += cls_loss.item() * batch_size
                running_reg_loss += reg_loss.item() * batch_size
                running_corrects += torch.sum(preds == labels).item()
                if attention_flag:
                    for i in range(nparts):
                        running_corrects_parts[i] += torch.sum(predl[i] == labels).item()

                # log variables (compare phases with ==, never `is`)
                global_step += 1
                if global_step % 100 == 1 and writer is not None and is_train:
                    batch_loss = cls_loss.item() + reg_loss.item()
                    writer.add_scalar('running loss/running_train_loss', batch_loss, global_step)
                    writer.add_scalar('running loss/running_cls_loss', cls_loss, global_step)
                    if attention_flag:
                        writer.add_scalar('running loss/running_lmgm_reg_loss', lmgm_reg_loss, global_step)
                        writer.add_scalar('running loss/running_entropy_reg_loss', entropy_loss, global_step)
                        writer.add_scalar('running loss/running_soft_reg_loss', soft_loss, global_step)
                    elif maxent_flag:
                        writer.add_scalar('running loss/running_maxent_reg_loss', reg_loss, global_step)
                    for name, param in model.named_parameters():
                        writer.add_histogram('params_in_running/'+name, param.data.clone().cpu().numpy(), global_step)

            ############################################### for each epoch

            # Epoch loss (classification part only) and accuracy;
            # max(..., 1) guards against an empty split.
            n_samples = max(dataset_sizes[phase], 1)
            epoch_loss = running_cls_loss / n_samples
            epoch_acc = running_corrects / n_samples
            if attention_flag:
                epoch_acc_parts = [c / n_samples for c in running_corrects_parts]

            # log variables for each epoch
            if writer is not None:
                if is_train:
                    writer.add_scalar('epoch loss/train_epoch_loss', epoch_loss, epoch)
                    writer.add_scalar('accuracy/train_epoch_acc', epoch_acc, epoch)
                    if attention_flag:
                        for i in range(nparts):
                            writer.add_scalar('accuracy/train_acc_part{}_acc'.format(i), epoch_acc_parts[i], epoch)
                    for name, param in model.named_parameters():
                        writer.add_histogram('params_in_epoch/'+name, param.data.clone().cpu().numpy(), epoch)
                else:
                    writer.add_scalar('epoch loss/eval_epoch_loss', epoch_loss, epoch)
                    writer.add_scalar('accuracy/eval_epoch_acc', epoch_acc, epoch)
                    if attention_flag:
                        for i in range(nparts):
                            writer.add_scalar('accuracy/eval_acc_part{}_acc'.format(i), epoch_acc_parts[i], epoch)

            # print to log file
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc), file=output_log_file)
            if is_train:
                print('current lr: {}'.format(optimizer.param_groups[0]['lr']), file=output_log_file)
            else:
                print('\n', file=output_log_file)

            # print to terminal
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if is_train:
                print('current lr: {}'.format(optimizer.param_groups[0]['lr']))

            # deep copy the model when the evaluation accuracy improves
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_epoch = epoch
                best_step = global_step_resume
                best_model_params = copy.deepcopy(model.state_dict())

        # adjust learning rate after each epoch
        scheduler.step()

        print()

    time_elapsed = time.time() - since
    for stream in (output_log_file, None):
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60), file=stream)
        print('Best test Acc: {:.4f}'.format(best_acc), file=stream)

    # recording training params
    rsltparams = dict()
    rsltparams['datasetname'] = datasetname
    rsltparams['nparts'] = model.nparts
    rsltparams['val_acc'] = float(best_acc)
    # Regularization weight exists only when a second criterion was supplied.
    rsltparams['lmgm'] = getattr(criterion[1], 'rho', None) if len(criterion) > 1 else None
    rsltparams['lr'] = optimizer.param_groups[0]['lr']
    rsltparams['best_epoch'] = best_epoch
    rsltparams['best_step'] = best_step
    rsltparams['soft_weights'] = penalty['soft_weights']
    rsltparams['entropy_weights'] = penalty['entropy_weights']

    # load best model weights
    model.load_state_dict(best_model_params)
    return model, rsltparams

## Start training a model

In [None]:
isckpt = False  # set True to load learned models from checkpoint (default False)

# One-line summary of this run's configuration, echoed to the terminal and the log file.
indfile = "{}: opt={}, lr={}, lmgm={}, nparts={}, entropy={}, soft={}, epochs={}, imgsz={}, batch_sz={}".format(
    datasetname, optmeth, lr, lmgm, nparts, entropy, soft, epochs, image_size, batchsize)
for _stream in (None, logfile):   # None -> stdout
    print("\n{}\n".format(indfile), file=_stream)

model, train_rsltparams = train(
    model, dataloader, criterion, optimizer, scheduler,
    datasetname=datasetname, isckpt=isckpt, epochs=epochs,
    networkname=networkname, writer=writer, device=device, maxent_flag=maxent_flag,
    soft_weights=soft, entropy_weights=entropy, logfile=logfile)

# Store the run configuration next to the training results.
train_rsltparams.update(imgsz=image_size, epochs=epochs, init_lr=lr, batch_sz=batchsize)

for _stream in (None, logfile):
    print('\nBest epoch: {}'.format(train_rsltparams['best_epoch']), file=_stream)
    print("\n{}\n".format(indfile), file=_stream)
print('\nWorking on cluster: {}\n'.format(device_name))

logfile.close()

## Save the model

In [None]:
# Save the best weights together with the training results. Create the output
# folder first so torch.save does not fail when it does not exist yet.
os.makedirs(modelpath, exist_ok=True)
torch.save({'model_params': model.state_dict(), 'train_params': train_rsltparams}, os.path.join(modelpath, modelname))

In [None]:
%tensorboard --logdir=runs

## Evaluation function

In [None]:
def eval(model, dataloader=None, device='cpu', datasetname=None):
    """Evaluate ``model`` on ``dataloader`` and collect per-sample diagnostics.

    Note: the name shadows the builtin ``eval``; it is kept for compatibility
    with the calling cells.

    Args:
        model: classifier exposing an ``attention`` attribute. When it is truthy,
            the forward pass must return ``(logits, xlocal, xcosin, xmaps)``;
            otherwise it returns just the logits of shape (batch, classes).
        dataloader: iterable of ``(inputs, labels)`` batches that also has a
            ``dataset`` attribute (used for the sample count).
        device: device the batches are moved to before the forward pass.
        datasetname: one of the supported dataset keys. Anything else prints
            ``"illegal dataset"`` and returns ``None``.

    Returns:
        dict with:
          acc        -- overall accuracy, correct / len(dataloader.dataset)
          avg_acc    -- mean of the per-class accuracies over the classes that
                        appear in the labels
          good_data / bad_data   -- numpy copies of correctly / wrongly
                                    classified inputs
          good_index / bad_index -- per-batch indices of those samples
          xlocal, xcosin, xmaps, outputs -- model outputs of the LAST batch only
                                    (the attention outputs are None for models
                                    without attention)
          maps, labels, inputs   -- per-batch numpy copies (``maps`` stays empty
                                    for models without attention)
    """
    if not datasetname or datasetname not in ['stdogs', 'wyzedogs', 'wyze', 'yt', 'google', 'stdogs25']:
        print("illegal dataset")
        return

    attention_flag = model.attention
    model.eval()
    datasize = len(dataloader.dataset)
    running_corrects = 0
    good_data = []
    bad_data = []
    gi = []
    bi = []
    maps = []
    lab = []
    inp = []
    num_label_counts = dict()      # label -> number of samples seen
    correct_label_counts = dict()  # label -> number predicted correctly
    # Bound up front so models without attention (and empty loaders) do not
    # raise NameError when the last-batch outputs are stored below.
    outputs = xlocal = xcosin = xmaps = None

    for inputs, labels in dataloader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            if attention_flag:
                outputs, xlocal, xcosin, xmaps = model(inputs)
            else:
                outputs = model(inputs)

        # Keep numpy copies of every batch so they can be returned.
        if xmaps is not None:
            maps.append(xmaps.cpu().numpy())
        lab.append(labels.cpu().numpy())
        inp.append(inputs.cpu().numpy())

        # Softmax over the class dimension; the argmax gives the prediction.
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(probs, 1)

        # Split the batch into correctly / wrongly classified samples.
        good_mask = preds == labels
        bad_mask = torch.logical_not(good_mask)
        running_corrects += int(good_mask.sum().item())
        good_index = good_mask.nonzero()
        gi.append(good_index.cpu().numpy())
        bad_index = bad_mask.nonzero()
        bi.append(bad_index.cpu().numpy())
        for idx in good_index:
            good_data.append(inputs[idx].cpu().numpy())
        for idx in bad_index:
            bad_data.append(inputs[idx].cpu().numpy())

        # Per-class tallies for the class-averaged accuracy.
        for label, correct in zip(labels.tolist(), good_mask.tolist()):
            num_label_counts[label] = num_label_counts.get(label, 0) + 1
            correct_label_counts[label] = correct_label_counts.get(label, 0) + int(correct)

    acc = running_corrects / datasize if datasize else 0.0
    # Mean of the per-class accuracies (previously always reported as 0.0).
    class_accs = [correct_label_counts[k] / n for k, n in num_label_counts.items()]
    avg_acc = sum(class_accs) / len(class_accs) if class_accs else 0.0
    print("General Accuracy: {}".format(acc))

    rsltparams = dict()
    rsltparams['acc'] = acc
    rsltparams['avg_acc'] = avg_acc
    rsltparams['good_data'] = good_data
    rsltparams['bad_data'] = bad_data
    rsltparams['good_index'] = gi
    rsltparams['bad_index'] = bi
    rsltparams['xlocal'] = xlocal
    rsltparams['xcosin'] = xcosin
    rsltparams['xmaps'] = xmaps
    rsltparams['outputs'] = outputs
    rsltparams['maps'] = maps
    rsltparams['labels'] = lab
    rsltparams['inputs'] = inp

    return rsltparams

## Evaluate a model

In [None]:
%%time

# Import undistort transform
# Must be used before resize transform because it accepts (1920, 1080) Wyze Cam images
# Only use undistort on wyze set
from undistort import undistort
import cv2
# from utils.mydataloader import DataLoader

eps = torch.finfo().eps
device = torch.device("cuda:0" if torch.cuda.is_available() > 0 else "cpu")
progpath = '/content/drive/MyDrive/Capstone/SEF-master'
sys.path.append(progpath)
datapath = '/content/drive/MyDrive/Capstone/'
modelzoopath = '/content/drive/MyDrive/Capstone/SEF-master/models/'
sys.path.append(os.path.realpath(modelzoopath))
modelpath = os.path.join(progpath, 'models')
resultpath = os.path.join(progpath, 'runs')
datasetname = 'stdogs25'

# modelname = r'stdogs-net50-att1-lmgm0-entropy0-soft0-lr0.01-imgsz448-bsz32.model' # ResNet-50 base model, 120
# modelname = r'stdogs-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz32.model' # ResNet-50 with SEF, 120
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz32.model' # ResNet-50 with SEF trained on stdogs25, no augmentation
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz32-aug.model' # ResNet-50 with SEF trained on stdogs25, with augmentation
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-aug.model' # ResNet-50 with SEF trained on stdogs25, with augmentation

# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-1epoch.model' # ResNet-50 with SEF trained on stdogs25 for 1 epoch
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-5epoch.model' # ResNet-50 with SEF trained on stdogs25 for 5 epochs
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-10epoch.model' # ResNet-50 with SEF trained on stdogs25 for 10 epochs
modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-50epoch.model' # ResNet-50 with SEF trained on stdogs25 for 50 epochs

# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-100epoch-aug5.model' # ResNet-50 with SEF trained on stdogs25 for 100 epochs, augmentation x5, lr scheduler 10 steps


# modelname = r'stdogs25-net18-att2-lmgm1-entropy1-soft0.05-lr0.001-imgsz448-bsz32-50epoch.model' # ResNet-18 with SEF trained on stdogs25, no aug, 0.001 LR
# modelname = r'stdogs25-net18-att1-lmgm0-entropy0-soft0-lr0.001-imgsz448-bsz32-50epoch.model' # ResNet-18 vanilla
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.001-imgsz448-bsz16-50epoch.model'
# modelname = r'stdogs25-net50-att1-lmgm0-entropy0-soft0-lr0.001-imgsz448-bsz16-50epoch.model'




load_params = torch.load(os.path.join(modelpath, modelname), map_location=device)
networkname = modelname.split('-')[1]

model_state_dict, train_params = load_params['model_params'], load_params['train_params']

nparts = train_params['nparts']
lmgm = train_params['lmgm']
entropy = train_params['entropy_weights']
soft = train_params['soft_weights']
batchsize = train_params['batch_sz']
imgsz = train_params['imgsz']
lr = train_params['init_lr']
if datasetname == 'stdogs': num_classes = 120
if datasetname == 'stdogs25': num_classes = 25
attention_flag = True if nparts > 1 else False
netframe = 'resnet50' if networkname.find('50') > -1 else 'resnet18'

model = resnet50(pretrained=False, model_dir=modelzoopath, nparts=nparts, num_classes=num_classes, attention=attention_flag, device=device)
model.load_state_dict(model_state_dict, strict=True)
model.to(device)


datasetpath = os.path.join(datapath, datasetname)
datasetpath = os.path.join(datapath, 'stdogs25')

# Transforms for test images
# Undsistort transform only for testing on wyze images
data_transform = {
    'wyze': transforms.Compose([
        transforms.Lambda(undistort),
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'yt': transforms.Compose([
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'google': transforms.Compose([
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# testset = 'wyze'
testset = 'google'
# testset = 'yt'

# Batch size
batch_size = 1

test_transform = data_transform[testset]

testsplit = ImageFolder(os.path.join(datasetpath, testset), data_transform[testset])
# Number of images in testset
N = len(testsplit)
testloader = DataLoader(testsplit, batch_size=batch_size, shuffle=False, num_workers=4)
test_rsltparams = eval(model, testloader, datasetname=datasetname, device=device)

print('General Acc: {}, Class Avg Acc: {}'.format(test_rsltparams['acc'], test_rsltparams['avg_acc']))

In [None]:
# Optional: load the 1-, 5- and 50-epoch checkpoints side by side as
# model1 / model5 / model50, e.g. to compare their activation maps.
# Uncomment to use. Each model reuses nparts, num_classes and attention_flag,
# which must already be set (the evaluation cell sets them).
# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-1epoch.model'
# load_params = torch.load(os.path.join(modelpath, modelname), map_location=device)
# networkname = modelname.split('-')[1]
# model_state_dict, train_params = load_params['model_params'], load_params['train_params']
# model1 = resnet50(pretrained=False, model_dir=modelzoopath, nparts=nparts, num_classes=num_classes, attention=attention_flag, device=device)
# model1.load_state_dict(model_state_dict, strict=True)
# model1.to(device)


# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-5epoch.model'
# load_params = torch.load(os.path.join(modelpath, modelname), map_location=device)
# networkname = modelname.split('-')[1]
# model_state_dict, train_params = load_params['model_params'], load_params['train_params']
# model5 = resnet50(pretrained=False, model_dir=modelzoopath, nparts=nparts, num_classes=num_classes, attention=attention_flag, device=device)
# model5.load_state_dict(model_state_dict, strict=True)
# model5.to(device)


# modelname = r'stdogs25-net50-att2-lmgm1-entropy1-soft0.05-lr0.01-imgsz448-bsz16-50epoch.model'
# load_params = torch.load(os.path.join(modelpath, modelname), map_location=device)
# networkname = modelname.split('-')[1]
# model_state_dict, train_params = load_params['model_params'], load_params['train_params']
# model50 = resnet50(pretrained=False, model_dir=modelzoopath, nparts=nparts, num_classes=num_classes, attention=attention_flag, device=device)
# model50.load_state_dict(model_state_dict, strict=True)
# model50.to(device)

## Visualize activation maps

In [None]:
import cv2
from PIL import Image

# Sample Stanford Dogs images to visualise (path3 is used below).
path1 = '/content/drive/MyDrive/Capstone/stdogs25/Images/n02086646-Blenheim_spaniel/n02086646_602.jpg'
path2 = '/content/drive/MyDrive/Capstone/stdogs25/Images/n02099601-golden_retriever/n02099601_304.jpg'
path3 = '/content/drive/MyDrive/Capstone/stdogs25/Images/n02110185-Siberian_husky/n02110185_248.jpg'

transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# convert('RGB') so grayscale / RGBA / palette files still yield the 3 channels
# the Normalize above expects.
img = Image.open(path3).convert('RGB')
input_tensor = transform(img).unsqueeze(0).to(device)

# Model to visualise: `model` is the checkpoint loaded by the evaluation cell.
# To compare checkpoints, uncomment the cell that builds model1 / model5 /
# model50 and assign one of them here (e.g. vis_model = model50).
vis_model = model
vis_model.eval()  # make sure the model is in inference mode

# no_grad: the maps are only plotted, and Tensor.numpy() raises on tensors
# that require grad.
with torch.no_grad():
    xglobal, xlocal, xcosin, xmaps = vis_model(input_tensor)
data = xmaps.cpu().numpy()

# Collapse each half of the channel axis into one activation map per group
# (channels [0, C/2) -> att1, [C/2, C) -> att2), upsampled to the image size.
half = data.shape[1] // 2
att1 = cv2.resize(data[0, :half].sum(axis=0), (448, 448))
att2 = cv2.resize(data[0, half:].sum(axis=0), (448, 448))

# Get input image at the same resolution as the maps
img = img.resize((448, 448))


fig, ax = plt.subplots(1)
# Plot image
ax.imshow(img)
# Overlay the group-2 activation map (use att1 for group 1)
# p1 = ax.pcolormesh(att1, alpha=0.1)
p2 = ax.pcolormesh(att2, alpha=0.1)
# fig.colorbar(p2, ax=ax)
# plt.savefig(f'husky248_group2_1epoch.png')

In [None]:
def _tile_channel_maps(maps):
    """Lay out a (channels, h, w) stack as one near-square 2-D mosaic.

    Tiles are placed row-major. Unused grid cells are zero-filled.
    """
    n_maps, map_h, map_w = maps.shape
    cols = max(1, int(np.ceil(np.sqrt(n_maps))))
    rows = int(np.ceil(n_maps / cols))
    pad = np.zeros((rows * cols - n_maps, map_h, map_w), dtype=maps.dtype)
    return (np.concatenate([maps, pad])
            .reshape(rows, cols, map_h, map_w)
            .transpose(0, 2, 1, 3)
            .reshape(rows * map_h, cols * map_w))


# Show the first 1024 channel maps of `data` (first image in the batch) in one
# figure. Calling imshow once per channel on the same axes only left the last
# channel visible while stacking 1024 image artists.
plt.imshow(_tile_channel_maps(data[0, :1024]))

In [None]:
def _tile_channel_maps(maps):
    """Lay out a (channels, h, w) stack as one near-square 2-D mosaic.

    Tiles are placed row-major. Unused grid cells are zero-filled.
    """
    n_maps, map_h, map_w = maps.shape
    cols = max(1, int(np.ceil(np.sqrt(n_maps))))
    rows = int(np.ceil(n_maps / cols))
    pad = np.zeros((rows * cols - n_maps, map_h, map_w), dtype=maps.dtype)
    return (np.concatenate([maps, pad])
            .reshape(rows, cols, map_h, map_w)
            .transpose(0, 2, 1, 3)
            .reshape(rows * map_h, cols * map_w))


# Inspect the attention outputs of the LAST evaluated batch (eval() keeps only
# that batch's xmaps / xcosin). view(1, ...) assumes the batch held one image.
xmaps = test_rsltparams['xmaps']
n_channels, map_h, map_w = xmaps.shape[1:]
x = xmaps.view(1, n_channels, -1)
print(x.shape)
xl = x.cpu().numpy().squeeze()  # (channels, map_h * map_w)

xcosin = test_rsltparams['xcosin']
a = xcosin.cpu().numpy().squeeze()

# Project the flattened maps through xcosin and fold the result back into
# map_h x map_w spatial maps (previously hard-coded as 2048 x 14 x 14).
xl_ = np.matmul(a, xl)
xl_ = np.reshape(xl_, (-1, map_h, map_w))
print(np.shape(xl_))

# One mosaic of all channel maps instead of overlaying one imshow per channel
# (which recomputed the numpy copy on every iteration and only showed the last map).
plt.imshow(_tile_channel_maps(xmaps.cpu().numpy()[0]))