<a href="https://colab.research.google.com/github/chandruCS165/DRE_net_for_finding_covid/blob/main/DRE_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Drive Mount

In [1]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
import torch
devices = [d for d in range(torch.cuda.device_count())]
device_names  = [torch.cuda.get_device_name(d) for d in devices]
print(device_names)

['Tesla K80']


# Configuration

In [3]:
BATCH_SIZE = 8
PROPOSAL_NUM = 6
NO_OF_EPOCHS = 4
CAT_NUM = 6
INPUT_SIZE = (448, 448)     # (448, 448)  # (w, h)
LR = 0.0008                  # 0.002
WD = 1e-4
SAVE_FREQ = 2
# resume = ''#'./20200223_152850/002.ckpt'
test_model = 'model.ckpt'
save_dir = './'
BRIGHTNESS = 0.5
SATURATION = 0.5
MEAN = [0.485, 0.456, 0.406]
SD = [0.229, 0.224, 0.225]
#resume=''

# DataPreprocessing

In [4]:
import numpy as np
import os
from PIL import Image
from torchvision import transforms
import random
random.seed(0)

class DataPreprocessingTrain():
    def __init__(self, root_dir):
        """
        Class Usage(Construtor)
            Used for Data augumentation and data preprocessing of training data

            Parameters:
            root_dir (String): Path of root directory for image dataset

            Returns:
                Returns nothing 

        """
        # image_file_lst stores the image file's path
        self.image_file_lst = []     

        # label_lst stores corresponding labels of image_file_lst                      
        self.label_lst = []

        # This code intialize the image_file_lst and label_lst
        # LABEL = {0:"no_nCoV",1 :"nCoV"}
        # root_dir = input/
        Label = {0:"no_nCoV",1 :"nCoV"}
        for label,image_dir in Label.items():
            image_path = f'{root_dir}/{image_dir}'
            self.image_file_lst += [f'{image_path}/{item}' for item in os.listdir(image_path)]
            self.label_lst += [label]*len(os.listdir(image_path))
        
        # Shuffles the dataset
        temp = list(zip(self.image_file_lst, self.label_lst))
        random.shuffle(temp)
        self.image_file_lst = [item[0] for item in temp]
        self.label_lst = [item[1] for item in temp]
        
    def __getitem__(self, index):
        """
        __getitem__(getter)
            Returns the preprocessed image with corresponding label and raw image 

            Parameters:
                index (String): Index of image and label pair needed

            Returns:
                image (tensor): Return the Preprocessed image in tensor format
                target (int): Returns the actual label of image(tensor)
                img_raw (tensor): Return the raw image in tensor format
        """
        image = np.array(Image.open(self.image_file_lst[index]))
        target = self.label_lst[index]

        #If image is gray scale change into RGB
        if len(image.shape) == 2:
            image = np.stack([image] * 3, 2)
        
        #take the copy of the image to img_raw
        img_raw = image.copy()
        image = Image.fromarray(image, mode='RGB')
        image = transforms.Resize(INPUT_SIZE, transforms.InterpolationMode.BILINEAR)(image)
        
        #Data Augumentation 
        flg_H = 0
        if np.random.randint(2) == 1:
            flg_H = 1
            image = transforms.RandomHorizontalFlip(p=1)(image)
        image = transforms.ColorJitter(brightness=BRIGHTNESS, saturation=SATURATION)(image)
        image = transforms.ToTensor()(image)
        image = transforms.Normalize(MEAN,SD)(image)
        img_raw = Image.fromarray(img_raw, mode='RGB')
        img_raw = transforms.Resize((600, 600), transforms.InterpolationMode.BILINEAR)(img_raw)
        if flg_H == 1:
            img_raw = transforms.RandomHorizontalFlip(p=1)(img_raw)
        img_raw = transforms.ToTensor()(img_raw)
        img_raw = transforms.Normalize(MEAN, SD)(img_raw)


        return image, target, img_raw
    def __len__(self):
        """
        Parameters:
            No Parameters

        Returns:
            (int) : length of the dataset
        """
        return len(self.label_lst)


class DataPreprocessingVal():
    def __init__(self, root_dir):
        """
        Class Usage(Construtor)
            Used for Data augumentation and data preprocessing of validation data and testing data
            
            Parameters:
            root_dir (String): Path of root directory for image dataset

            Returns:
                Returns nothing 

        """
        # image_file_lst stores the image file's path
        self.image_file_lst = []     
        Label = {0:"no_nCoV",1 :"nCoV"}
        # label_lst stores corresponding labels of image_file_lst                      
        self.label_lst = []

        # This code intialize the image_file_lst and label_lst
        # Label = {0:"no_nCoV",1 :"nCoV"}
        for label,image_dir in Label.items():
            image_path = f'{root_dir}/{image_dir}'
            self.image_file_lst += [f'{image_path}/{item}' for item in os.listdir(image_path)]
            self.label_lst += [label]*len(os.listdir(image_path))
        
        # Shuffles the dataset
        temp = list(zip(self.image_file_lst, self.label_lst))
        random.shuffle(temp)
        self.image_file_lst = [item[0] for item in temp]
        self.label_lst = [item[1] for item in temp]
        
    def __getitem__(self, index):
        """
        __getitem__(getter)
            Returns the preprocessed image with corresponding label and raw image 

            Parameters:
                index (String): Index of image and label pair needed

            Returns:
                image (tensor): Return the Preprocessed image in tensor format
                target (int): Returns the actual label of image(tensor)
                img_raw (tensor): Return the raw image in tensor format
        """
        image = np.array(Image.open(self.image_file_lst[index]))
        target = self.label_lst[index]
        #If image is gray scale change into RGB
        if len(image.shape) == 2:
            image = np.stack([image] * 3, 2)
        
        img_raw = image.copy()
        image = Image.fromarray(image, mode='RGB')
        image = transforms.Resize(INPUT_SIZE, transforms.InterpolationMode.BILINEAR)(image)
        
        #Data Augumentation 
        image = transforms.ColorJitter(brightness=BRIGHTNESS, saturation=SATURATION)(image)
        image = transforms.ToTensor()(image)
        image = transforms.Normalize(MEAN,SD)(image)
        img_raw = Image.fromarray(img_raw, mode='RGB')
        img_raw = transforms.Resize((600, 600), transforms.InterpolationMode.BILINEAR)(img_raw)
        img_raw = transforms.ToTensor()(img_raw)
        img_raw = transforms.Normalize(MEAN, SD)(img_raw)


        return image, target, img_raw

    def __len__(self):
        """
        __len__(getter)

            Parameters:
                No Parameters


            Returns:
                (int) : length of the dataset
        """
        return len(self.label_lst)




# ResNet

In [5]:
#
# resnet link : https://pytorch.org/hub/pytorch_vision_resnet/

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        feature_map = x                    #feature1 --> (batchsize,2048,14,14)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = nn.Dropout(p=0.5)(x)
        feature2 = x                    #
        x = self.fc(x)

        return x, feature_map, feature2

def resnet50(pretrained=False, **kwargs):
    """
    Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth'))
    return model

In [6]:
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
m = resnet50(pretrained=True).to(device)
m.avgpool = nn.AdaptiveAvgPool2d(1).to(device)
m.fc = nn.Linear(512 * 4, 2).to(device)
print()
print("Summary of resnet50 for (3,448,448) ")
print()
summary(m, (3, 448, 448))
print()
print("Summary of resnet50 for (3,224,224) ")
print()
summary(m, (3, 224, 224))


Summary of resnet50 for (3,448,448) 

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]           4,096
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
           Conv2d-11        [-1, 256, 112, 112]          16,384
      BatchNorm2d-12        [-1, 256, 112, 112]             512
           Conv2d-13        [-1, 256, 112, 112]          16,384


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


          Conv2d-104          [-1, 256, 28, 28]         589,824
     BatchNorm2d-105          [-1, 256, 28, 28]             512
            ReLU-106          [-1, 256, 28, 28]               0
          Conv2d-107         [-1, 1024, 28, 28]         262,144
     BatchNorm2d-108         [-1, 1024, 28, 28]           2,048
            ReLU-109         [-1, 1024, 28, 28]               0
      Bottleneck-110         [-1, 1024, 28, 28]               0
          Conv2d-111          [-1, 256, 28, 28]         262,144
     BatchNorm2d-112          [-1, 256, 28, 28]             512
            ReLU-113          [-1, 256, 28, 28]               0
          Conv2d-114          [-1, 256, 28, 28]         589,824
     BatchNorm2d-115          [-1, 256, 28, 28]             512
            ReLU-116          [-1, 256, 28, 28]               0
          Conv2d-117         [-1, 1024, 28, 28]         262,144
     BatchNorm2d-118         [-1, 1024, 28, 28]           2,048
            ReLU-119         [-1, 1024, 

# Anchors

In [7]:
import numpy as np
default_anchors_small = (
    dict(stride=32, size=48),
    dict(stride=64, size=96),
)
default_anchors_large = (
    dict(stride=128, size=192),
)
def intialize_anchor_maps(setting='small'):
    """
    intialize_anchor_maps
    This function is used to set the anchors in default position before training data

    Parameter:
        settings (String): This parameter determines the size of the anchors

    Return: 
        edge_anchors (List): list of the egde anchors with top-left corner(y0,x0) and bottom right corner(y1,x1) (y0, x0, y1, x1)
    """
    if setting == 'small':
        anchors_setting = default_anchors_small
    else:
        anchors_setting = default_anchors_large

    edge_anchors = np.zeros((0, 4), dtype=np.float32)
    input_shape = np.array(INPUT_SIZE, dtype=int)

    for anchor_info in anchors_setting:

        stride = anchor_info['stride']
        size = anchor_info['size']

        output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
        output_map_shape = output_map_shape.astype(np.int)
        output_shape = tuple(output_map_shape) + (4,)
        start = stride / 2.
        oy = np.arange(start, start + stride * output_shape[0], stride)
        oy = oy.reshape(output_shape[0], 1)
        ox = np.arange(start, start + stride * output_shape[1], stride)
        ox = ox.reshape(1, output_shape[1])
        center_anchor_map = np.zeros(output_shape, dtype=np.float32)
        center_anchor_map[:, :, 0] = oy
        center_anchor_map[:, :, 1] = ox
        edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - size / 2.,
                                            center_anchor_map[..., :2] + size / 2.),
                                            axis=-1)
        edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))

    return edge_anchors

#*
def hard_nms(cdds, topk=6, iou_thresh=0.25):

    cdds = cdds.copy()
    indices = np.argsort(cdds[:, 0])
    cdds = cdds[indices]
    cdd_results = []

    res = cdds

    while res.any():
        cdd = res[-1]
        cdd_results.append(cdd)
        if len(cdd_results) == topk:
            return np.array(cdd_results)
        res = res[:-1]

        start_max = np.maximum(res[:, 1:3], cdd[1:3])
        end_min = np.minimum(res[:, 3:5], cdd[3:5])
        lengths = end_min - start_max
        intersec_map = lengths[:, 0] * lengths[:, 1]
        intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
        iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
            cdd[4] - cdd[2]) - intersec_map)
        res = res[iou_map_cur <= iou_thresh]

    return np.array(cdd_results)

def get_xy(y0, x0, y1, x1, size=448):
    pad_size = size//2
    
    y0 = np.max([y0, pad_size])
    y0 = np.min([y0, size+pad_size])
    
    x0 = np.max([x0, pad_size])
    x0 = np.min([x0, size+pad_size])
    
    y1 = np.max([y1, pad_size])
    y1 = np.min([y1, size+pad_size])
    
    x1 = np.max([x1, pad_size])
    x1 = np.min([x1, size+pad_size])
    
    return y0, x0, y1, x1

# DRE-net

In [8]:
from torch import nn
import torch
from torch.autograd import Variable
import torchvision
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np

class FPN(nn.Module):
    def __init__(self):
        super(FPN, self).__init__()
        self.conv1 = nn.Conv2d(2048, 128, 3, 1, 1)  # 2048 --> 128  , 3 X 3, strides = 1,padding = 1
        self.conv2 = nn.Conv2d(128, 128, 3, 2, 1)   
        self.conv3 = nn.Conv2d(128, 128, 3, 2, 1)
        self.ReLU = nn.ReLU()
        self.order1 = nn.Conv2d(128, 1, 1, 1, 0)
        self.order2 = nn.Conv2d(128, 1, 1, 1, 0)
        self.order3 = nn.Conv2d(128, 1, 1, 1, 0)

    def forward(self, x):
        batch_size = x.size(0)
        x1 = self.ReLU(self.conv1(x))
        x2 = self.ReLU(self.conv2(x1))
        x3 = self.ReLU(self.conv3(x2))
        t1 = self.order1(x1).view(batch_size, -1)
        t2 = self.order2(x2).view(batch_size, -1)
        t = self.order3(x3).view(batch_size, -1)
        return torch.cat((t1, t2), dim=1), t


class DRE_net(nn.Module):
    def __init__(self, topK=6, n_class=2):
        super(DRE_net, self).__init__()
        self.n_class = n_class
        self.resNet = resnet50(pretrained=True)
        self.resNet.avgpool = nn.AdaptiveAvgPool2d(1)
        self.resNet.fc = nn.Linear(512 * 4, self.n_class)
        self.fpn_net = FPN()
        self.topK = topK
        self.mlp = nn.Linear(2048 * (CAT_NUM + 1 + 1), self.n_class)
        self.sub_mlp = nn.Linear(512 * 4, self.n_class)
        
        self.pad_side = 224

        #intializing the edge anchors
        edge_anchors_small= intialize_anchor_maps(setting='small')
        self.edge_anchors_small = (edge_anchors_small + 224).astype(np.int)
        edge_anchors_large= intialize_anchor_maps(setting='large')
        self.edge_anchors_large = (edge_anchors_large + 224).astype(np.int)
        
        
        

    def forward(self, image, img_raw):

        #image = (batch_size,3,448,448)
        #img_raw = (batch_size,3,600,600)
        resnet_out, feature_map, feature = self.resNet(image)   # x, feature1,feature2
        rn_logits = resnet_out
        image_pad = F.pad(image, (self.pad_side, self.pad_side, 
                          self.pad_side, self.pad_side), mode='constant', value=0)
        #pad = 224
        batch = image.size(0)
       
        # region_small = (batch,245,1)
        # region_large = (batch,16,1)
        region_score_small, region_score_large = self.fpn_net(feature_map.detach())
        all_subimg_small = [
            np.concatenate((x.reshape(-1, 1), 
                            self.edge_anchors_small.copy(), 
                            np.arange(0, len(x)).reshape(-1, 1)), axis=1)
            for x in region_score_small.data.cpu().numpy()]
        top_K_subimg_small = [hard_nms(x, topk=self.topK//2, iou_thresh=0.1) for x in all_subimg_small]
        top_K_subimg_small = np.array(top_K_subimg_small)
        top_K_index_small = top_K_subimg_small[:, :, -1].astype(np.int)
        top_K_index_small = torch.from_numpy(top_K_index_small).cuda()
        top_K_prob_small = torch.gather(region_score_small, dim=1, index=top_K_index_small)



        all_subimg_large = [
            np.concatenate((x.reshape(-1, 1), 
                            self.edge_anchors_large.copy(), 
                            np.arange(0, len(x)).reshape(-1, 1)), axis=1)
            for x in region_score_large.data.cpu().numpy()]
        top_K_subimg_large = [hard_nms(x, topk=self.topK//2, iou_thresh=0.1) for x in all_subimg_large]
        top_K_subimg_large = np.array(top_K_subimg_large)
        top_K_index_large = top_K_subimg_large[:, :, -1].astype(np.int)
        top_K_index_large = torch.from_numpy(top_K_index_large).cuda()
        top_K_prob_large = torch.gather(region_score_large, dim=1, index=top_K_index_large)
        
        sub_imgs = torch.zeros([batch, self.topK, 3, 224, 224]).cuda()


        #upSampling the all small and large sub images into 224 X 224
        for i in range(batch):
            for j in range(self.topK//2):
                [y0, x0, y1, x1] = top_K_subimg_small[i][j, 1:5].astype(np.int)
                sub_imgs[i:i + 1, j] = F.interpolate(image_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
                                                      align_corners=True)
                [y0, x0, y1, x1] = top_K_subimg_large[i][j, 1:5].astype(np.int)
                sub_imgs[i:i + 1, j+self.topK//2] = F.interpolate(image_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
                                                      align_corners=True)
                
        sub_imgs = sub_imgs.view(batch * self.topK, 3, 224, 224)
        #sending all subimgs of size 224 X 224 to resnet
        _, _, subimg_features = self.resNet(sub_imgs.detach())
        subimg_feature = subimg_features.view(batch, self.topK, -1)
        subimg_feature = subimg_feature[:, :CAT_NUM, ...].contiguous()
        #changing the dimension into (batch)
        subimg_feature = subimg_feature.view(batch, -1)



        image2 = image.clone()
        for bs in range(batch):
            [y0, x0, y1, x1] = top_K_subimg_large[bs][0, 1:5].astype(np.int)
            y0, x0, y1, x1 = get_xy(y0, x0, y1, x1)
            y0 = np.int((y0 - 224)/448*600)
            x0 = np.int((x0 - 224)/448*600)
            y1 = np.int((y1 - 224)/448*600)
            x1 = np.int((x1 - 224)/448*600)
            image2[bs] = F.interpolate(
                    img_raw[bs:bs + 1, :, y0:y1, x0:x1],
                    size=(448, 448), mode='bilinear', align_corners=True)
        _, _, feature2 = self.resNet(image2.detach()) # 
        
        top_K_index = torch.cat([top_K_index_small, top_K_index_large], 1)
        top_K_prob = torch.cat([top_K_prob_small, top_K_prob_large], 1)
        # mlp_logits have the shape: Batch*200
        mlp_in = torch.cat([subimg_feature, feature, feature2], dim=1)
        mlp_logits = self.mlp(mlp_in)
        # sub_logits have the shape: Batch*topK*200
        sub_logits = self.sub_mlp(subimg_features).view(batch, self.topK, -1)
        return [rn_logits, mlp_logits, sub_logits, 
                top_K_index, top_K_prob]

#*
def list_loss(logits, targets):
    temp = F.log_softmax(logits, -1)
    loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
    return torch.stack(loss)

#*
def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
    loss = Variable(torch.zeros(1).cuda())
    batch_size = score.size(0)
    for i in range(proposal_num):
        targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
        pivot = score[:, i].unsqueeze(1)
        loss_p = (1 - pivot + score) * targets_p
        loss_p = torch.sum(F.relu(loss_p))
        loss += loss_p
    return loss / batch_size

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
m = FPN().to(device)
print()
print("Summary of FPNs")
print()
summary(m, (2048,14,14))


Summary of FPNs

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 14, 14]       2,359,424
              ReLU-2          [-1, 128, 14, 14]               0
            Conv2d-3            [-1, 128, 7, 7]         147,584
              ReLU-4            [-1, 128, 7, 7]               0
            Conv2d-5            [-1, 128, 4, 4]         147,584
              ReLU-6            [-1, 128, 4, 4]               0
            Conv2d-7            [-1, 1, 14, 14]             129
            Conv2d-8              [-1, 1, 7, 7]             129
            Conv2d-9              [-1, 1, 4, 4]             129
Total params: 2,654,979
Trainable params: 2,654,979
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.53
Forward/backward pass size (MB): 0.51
Params size (MB): 10.13
Estimated Total Size (MB): 12.17
-------------------

# Main 

In [10]:
import os
os.chdir("/gdrive/MyDrive/Mini-Project/COVID19-CT-main/local_traniner/model")

import torch.utils.data
from torch.nn import DataParallel
import numpy as np

# read dataset

train_path = '../input/train/'
val_path = '../input/val/'
test_path = '../input/test/'
trainset = DataPreprocessingTrain(root_dir=train_path)
valset = DataPreprocessingVal(root_dir=val_path)
testset = DataPreprocessingVal(root_dir=test_path)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2, drop_last=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2, drop_last=False)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2, drop_last=False)

n_class = 2
print("dataset loading done")

# define model
model = DRE_net(topK=PROPOSAL_NUM, n_class=n_class)
creterion = torch.nn.CrossEntropyLoss()

# define optimizers
rn_parameters = list(model.resNet.parameters())
fpn_parameters = list(model.fpn_net.parameters())
mlp_parameters = list(model.mlp.parameters())
submlp_parameters = list(model.sub_mlp.parameters())

rn_optimizer = torch.optim.SGD(rn_parameters, lr=LR, momentum=0.9, weight_decay=WD)
mlp_optimizer = torch.optim.SGD(mlp_parameters, lr=LR, momentum=0.9, weight_decay=WD)
fpn_optimizer = torch.optim.SGD(fpn_parameters, lr=LR, momentum=0.9, weight_decay=WD)
submlp_optimizer = torch.optim.SGD(submlp_parameters, lr=LR, momentum=0.9, weight_decay=WD)

model = model.cuda()
model = DataParallel(model)

print(f"Starting training")
for epoch in range(1, NO_OF_EPOCHS + 1):
    model.train()
    train_correct = 0
    total = 0
    train_loss = 0
#===================================================================================================================
#
#                     Training 
#
#===================================================================================================================
    for i, data in enumerate(trainloader):
        image, label, img_raw = data[0].cuda(), data[1].cuda(), data[2]
        batch_size = image.size(0)

        #reset the grad
        rn_optimizer.zero_grad()
        fpn_optimizer.zero_grad()
        mlp_optimizer.zero_grad()
        submlp_optimizer.zero_grad()

        #input the batch images
        rn_logits, mlp_logits, sub_logits, _, top_k_prob = model(image, img_raw)

        #get the loss 
        fpn_loss = list_loss(sub_logits.view(batch_size * PROPOSAL_NUM, -1),
                                    label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
        rn_loss = creterion(rn_logits, label)
        mlp_loss = creterion(mlp_logits, label)
        rank_loss = ranking_loss(top_k_prob, fpn_loss)
        submlp_loss = creterion(sub_logits.view(batch_size * PROPOSAL_NUM, -1),
                                 label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))

        total_loss = rn_loss + rank_loss + mlp_loss + submlp_loss

        #take gradient
        total_loss.backward()

        #update the parameters
        rn_optimizer.step()
        fpn_optimizer.step()
        mlp_optimizer.step()
        submlp_optimizer.step()
        
        _, mlp_predict = torch.max(mlp_logits, 1)
        total += batch_size
        train_correct += torch.sum(mlp_predict.data == label.data)
        train_acc = float(train_correct) / total
        train_loss += mlp_loss.item() * batch_size
    train_loss = train_loss / total
    print()
    print("Train accuracy")
    print(
            'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
                 epoch,
                 train_loss,
                 train_acc,
                 total))
    if epoch % SAVE_FREQ == 0 :
        train_loss = 0
        train_correct = 0
        total = 0
        model.eval()
#=======================================================================================================
#
#                   Validation
#
#=======================================================================================================
        for i, data in enumerate(valloader):
            with torch.no_grad():
                image, label, img_raw = data[0].cuda(), data[1].cuda(), data[2]
                batch_size = image.size(0)
                _, mlp_logits, _, _, _, = model(image, img_raw)
                # calculating loss
                mlp_loss = creterion(mlp_logits, label)

                # calculating accuracy
                _, mlp_predict = torch.max(mlp_logits, 1)
              
                total += batch_size
                train_correct += torch.sum(mlp_predict.data == label.data)
                train_loss += mlp_loss.item() * batch_size
        train_acc = float(train_correct) / total
        train_loss = train_loss / total
        print()
        print("Validation Accuracy")
        print(
             'epoch:{} - val loss: {:.3f} and val acc: {:.3f} total sample: {}'.format(
                 epoch,
                 train_loss,
                 train_acc,
                 total))
#===========================================================================================================
#
#                   Testing
#
#============================================================================================================
	      # evaluation on test set
        test_loss = 0
        test_correct = 0
        total = 0
        for i, data in enumerate(testloader):
            with torch.no_grad():
                image, label, img_raw = data[0].cuda(), data[1].cuda(), data[2]
                batch_size = image.size(0)
                _, mlp_logits, _, _, _ = model(image, img_raw)
                # calculating loss
                mlp_loss = creterion(mlp_logits, label)

                # calculating accuracy
                _, mlp_predict = torch.max(mlp_logits, 1)

                total += batch_size
                test_correct += torch.sum(mlp_predict.data == label.data)
                test_loss += mlp_loss.item() * batch_size
        test_acc = float(test_correct) / total
        test_loss = test_loss / total

        print()
        print("Test accuracy ")
        print(
             'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
                 epoch,
                 test_loss,
                 test_acc,
                total))

print('training completed')

dataset loading done
Starting training

Train accuracy
epoch:1 - train loss: 1.197 and train acc: 0.703 total sample: 1148

Train accuracy
epoch:2 - train loss: 0.951 and train acc: 0.760 total sample: 1148

Validation Accuracy
epoch:2 - val loss: 0.638 and val acc: 0.855 total sample: 1148

Test accuracy 
epoch:2 - test loss: 1.164 and test acc: 0.804 total sample: 567

Train accuracy
epoch:3 - train loss: 0.671 and train acc: 0.780 total sample: 1148

Train accuracy
epoch:4 - train loss: 0.489 and train acc: 0.828 total sample: 1148

Validation Accuracy
epoch:4 - val loss: 0.871 and val acc: 0.856 total sample: 1148

Test accuracy 
epoch:4 - test loss: 1.769 and test acc: 0.795 total sample: 567
training completed
