In [5]:

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD, RMSprop
from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()
import torchvision.transforms.v2 as tf
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader
from tqdm import tqdm
import yaml
import os
import matplotlib.pyplot as plt
import random
from fixcap import FixCapsNet
from torchsummary import summary

# Number of DataLoader worker processes. os.cpu_count() returns None when the CPU
# count cannot be determined; fall back to 0 (load data in the main process)
# instead of handing None to DataLoader.
workers = os.cpu_count() or 0
# Device string ("pu" = processing unit) passed to every .to(...) call in this notebook.
pu = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training on CPU is refused on purpose: stop here rather than fall back silently.
assert pu == 'cuda', 'Connect to the GPU'

print(f'using {pu}, number of workers: {workers}')

# Ask PyTorch to hand back cached, currently unused GPU memory before building the model.
torch.cuda.empty_cache()


using cuda, number of workers: 8


In [6]:
# Root locations; every other path below is derived from these.
home_folder = '/home/user'
WORKSPACE = '/'.join((home_folder, 'Skin_Diseases_Detection'))
DATASET_FOLDER = '/'.join((home_folder, 'skdi_dataset', 'base_dir'))

# Dataset splits. Note that testing and validation both read ``val_dir``.
TRAINING_FOLDER = '/'.join((DATASET_FOLDER, 'train_dir'))
TESTING_FOLDER = '/'.join((DATASET_FOLDER, 'val_dir'))
VALIDATION_FOLDER = '/'.join((DATASET_FOLDER, 'val_dir'))

# Output locations for training artifacts, all inside the workspace.
CHECKPOINT_FOLDER = '/'.join((WORKSPACE, 'checkpoints'))
STATUS_FOLDER = '/'.join((WORKSPACE, 'status'))
PLOT_FOLDER = '/'.join((WORKSPACE, 'plots'))
PARAM_FOLDER = '/'.join((WORKSPACE, 'param'))
CONFIG_FOLDER = '/'.join((WORKSPACE, 'config'))
MODEL_FOLDER = '/'.join((WORKSPACE, 'model'))

class Dataset:
    """Builds the train / test / validation ``ImageFolder`` datasets and their loaders.

    Args:
        train_trans: transform applied to every training image.
        test_trans: transform applied to every test and validation image.
        train_batch_size: batch size used by all three loaders.

    Attributes:
        train_ds, test_ds, val_ds: the ``ImageFolder`` datasets.
        train_dl, test_dl, val_dl: the matching ``DataLoader`` objects.
    """
    def __init__(self, train_trans, test_trans, train_batch_size = 32):
        self.train_ds = ImageFolder(TRAINING_FOLDER, train_trans)
        self.test_ds = ImageFolder(TESTING_FOLDER, test_trans)
        self.val_ds = ImageFolder(VALIDATION_FOLDER, test_trans)

        # Settings shared by every loader so the three stay consistent.
        loader_kwargs = dict(batch_size=train_batch_size, num_workers=workers, pin_memory=True)

        # Only the training data is shuffled; evaluation metrics do not depend on
        # sample order, so the evaluation loaders iterate deterministically.
        self.train_dl = DataLoader(self.train_ds, shuffle=True, **loader_kwargs)
        self.test_dl = DataLoader(self.test_ds, shuffle=False, **loader_kwargs)
        self.val_dl = DataLoader(self.val_ds, shuffle=False, **loader_kwargs)

def make_cnn(dataset: 'ImageFolder', hid_layers = [64, 64],
            act_fn='relu', max_pool = None, pooling_after_layers = 2, dropout = 0.2, batch_norm=True,
            groups =1, bias=False, conv_layers=[[32, 3, 1],
                                                [16, 3, 1]]):
    """Build a plain CNN classifier sized for the images in ``dataset``.

    The network is: for every entry of ``conv_layers`` an unpadded ``Conv2d``
    (+ optional ``BatchNorm2d``) + activation, with an optional ``MaxPool2d``
    after every ``pooling_after_layers``-th convolution and after the last one;
    then ``Flatten``, the hidden ``Linear`` layers and a final ``Linear`` that
    outputs one value per class.

    Args:
        dataset: object whose ``dataset[0][0]`` is a ``(channels, height, width)``
            tensor and whose ``classes`` attribute lists the class names.
        hid_layers: widths of the hidden linear layers (may be empty).
        act_fn: one of ``'relu'``, ``'softplus'``, ``'tanh'``, ``'elu'``.
        max_pool: ``None`` for no pooling, else ``[kernel_size, stride]``.
        pooling_after_layers: pool after every this-many conv layers.
        dropout: dropout probability placed before the activation of every hidden
            layer except the first; ``None`` disables dropout.
        batch_norm: add ``BatchNorm2d`` after each convolution.
        groups: ``groups`` argument of every ``Conv2d``.
        bias: ``bias`` argument of every ``Conv2d``.
        conv_layers: list of ``[out_channels, kernel_size, stride]``.

    Returns:
        An ``nn.Sequential`` model.

    Raises:
        ValueError: unknown ``act_fn``, ``pooling_after_layers`` below 1 while
            pooling is enabled, or a layer that shrinks the feature map below 1x1.
        AssertionError: ``pooling_after_layers`` exceeds the number of conv layers.
    """
    # Factories rather than instances: every position in the network gets its own
    # activation module, and only the requested kind is ever constructed.
    activation_factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'softplus': nn.Softplus,
        'tanh': nn.Tanh,
        'elu': nn.ELU,
    }
    if act_fn not in activation_factories:
        raise ValueError(f'unknown act_fn {act_fn!r}; expected one of {sorted(activation_factories)}')
    make_activation = activation_factories[act_fn]

    assert pooling_after_layers <= len(conv_layers), 'exceeding the number conv layers..'
    if max_pool is not None and pooling_after_layers < 1:
        raise ValueError('pooling_after_layers must be at least 1 when max_pool is given')

    in_chann, inp_h, inp_w = dataset[0][0].shape
    num_of_classes = len(dataset.classes)

    def require_positive(height, width, stage):
        # Fail while building, with the offending stage named, instead of later
        # inside nn.Linear or at the first forward pass.
        if height < 1 or width < 1:
            raise ValueError(f'{stage} reduces the feature map to {height}x{width}; '
                             'use a larger input or smaller kernels/strides')

    layers = []
    last_conv = len(conv_layers) - 1
    for ind, (out_chann, filter_size, stride) in enumerate(conv_layers):
        layers.append(nn.Conv2d(in_chann, out_chann, filter_size, stride, groups=groups, bias=bias))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_chann))
        layers.append(make_activation())

        # Output size of a convolution without padding or dilation.
        inp_h = (inp_h - filter_size)//stride + 1
        inp_w = (inp_w - filter_size)//stride + 1
        require_positive(inp_h, inp_w, f'conv layer {ind}')

        if max_pool is not None and ((ind + 1) % pooling_after_layers == 0 or ind == last_conv):
            pool_size, pool_stride = max_pool[0], max_pool[1]
            layers.append(nn.MaxPool2d(pool_size, pool_stride))
            inp_h = (inp_h - pool_size)//pool_stride + 1
            inp_w = (inp_w - pool_size)//pool_stride + 1
            require_positive(inp_h, inp_w, f'max pooling after conv layer {ind}')
        in_chann = out_chann

    layers.append(nn.Flatten())
    in_dim = inp_h*inp_w*in_chann
    for position, out_dim in enumerate(hid_layers):
        layers.append(nn.Linear(in_dim, out_dim))
        # The first hidden layer has no dropout; later ones place it before the activation.
        if position > 0 and dropout is not None:
            layers.append(nn.Dropout(p=dropout))
        layers.append(make_activation())
        in_dim = out_dim

    layers.append(nn.Linear(in_dim, num_of_classes))
    return nn.Sequential(*layers)

def convert(**kwargs):
    """Collect the given keyword arguments into a plain ``dict`` and return it."""
    collected = {key: value for key, value in kwargs.items()}
    return collected

class Utils:
    """Persistence helpers: small text files, checkpoints and best-model tracking.

    The methods read attributes that the owner is expected to set before calling
    them: ``model``, ``optim``, ``param``, ``model_file``, ``status_file``,
    ``plot_file``, ``param_file`` and ``config_file``.
    """

    # Value stored in a freshly created param file; it means "no best metric yet".
    NO_BEST_PARAM = -1000.0
    # How many periodic checkpoint files are kept in CHECKPOINT_FOLDER.
    MAX_CHECKPOINTS = 7
    # Header line of the plot (metrics) file.
    PLOT_HEADER = 'Train_loss,Train_acc,Valid_loss,Valid_acc\n'

    def __init__(self):
        self.model = None
        self.optim = None
        self.model_file = ''
        self.status_file = ''
        self.plot_file = ''
        self.param_file = ''
        self.config_file = ''
        self.param = self.NO_BEST_PARAM
        self.min_loss = 0

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def read_file(self, path):
        """Return the first line of ``path`` (``''`` for an empty file)."""
        with open(path, 'r') as file:
            return file.readline()

    def write_file(self, path, content):
        """Write ``content`` to ``path``: appended for the plot file, overwriting otherwise."""
        mode = 'a' if path == self.plot_file else 'w'
        self._ensure_parent_dir(path)
        with open(path, mode) as file:
            file.write(content)

    def create_file(self, path):
        """Create ``path`` as an empty file, truncating it if it already exists."""
        self._ensure_parent_dir(path)
        with open(path, 'w'):
            pass

    def create_checkpoint_file(self, num):
        """Create an empty ``checkpoint_<num>.pth`` in CHECKPOINT_FOLDER and return its path."""
        path = f'{CHECKPOINT_FOLDER}/checkpoint_{num}.pth'
        self.create_file(path)
        return path
    
    def save_config(self, args: dict):
        """Dump ``args`` as YAML to ``self.config_file`` (overwriting it)."""
        self._ensure_parent_dir(self.config_file)
        with open(self.config_file, 'w') as file:
            yaml.safe_dump(args, file)

    def check_status_file(self):
        """Prepare a fresh run or a resume and return the first epoch to train.

        The status file holds the path of the latest checkpoint. When it is empty
        the plot file is reset to just its header and 0 is returned. Otherwise the
        checkpoint is loaded and the plot file is cut back to the header plus the
        rows of epochs ``0 .. checkpoint epoch``.

        Returns:
            0 for a fresh run, else ``checkpoint epoch + 1``: the checkpoint is
            written after its epoch has been trained, so that epoch is not repeated.
        """
        if not os.path.exists(self.status_file):
            self.create_file(self.status_file)
        checkpath = self.read_file(self.status_file).strip()

        if checkpath == '':
            self.create_file(self.plot_file)
            self.write_file(self.plot_file, self.PLOT_HEADER)
            self.model.train()
            return 0

        last_epoch = self.load_checkpoint(checkpath)
        lines = []
        if os.path.exists(self.plot_file):
            with open(self.plot_file, 'r') as file:
                lines = file.readlines()
        if not lines:
            lines = [self.PLOT_HEADER]
        # Header + one row per finished epoch (epochs are numbered from 0).
        self._ensure_parent_dir(self.plot_file)
        with open(self.plot_file, 'w') as file:
            file.writelines(lines[:last_epoch + 2])
        return last_epoch + 1

    def write_plot_data(self, data:list):
        """Append ``data`` to the plot file as one comma-separated row."""
        str_data = ','.join(map(str, data))
        self.write_file(self.plot_file, f'{str_data}\n')

    def save_checkpoint(self, epoch, checkpath):
        """Save model/optimizer state and ``epoch`` to ``checkpath`` and record it in the status file."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optim.state_dict(),
            'epoch': epoch
        }
        self._ensure_parent_dir(checkpath)
        torch.save(checkpoint, checkpath)
        # Point the status file at the checkpoint only once it is fully written.
        self._ensure_parent_dir(self.status_file)
        with open(self.status_file, 'w') as file:
            file.write(checkpath)
        print('checkpoint saved..')
    
    def load_checkpoint(self, checkpath):
        """Restore model and optimizer from ``checkpath``, switch to train mode, return the stored epoch."""
        print('loading checkpoint..')
        checkpoint = torch.load(checkpath)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
        self.model.train()
        print('checkpoint loaded...')
        return checkpoint['epoch']

    @staticmethod
    def _checkpoint_epoch(filename):
        """Return the epoch encoded in ``checkpoint_<epoch>.pth``, or ``None`` for any other name."""
        prefix, suffix = 'checkpoint_', '.pth'
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            return None
        number = filename[len(prefix):-len(suffix)]
        return int(number) if number.isdigit() else None

    def _prune_checkpoints(self, keep, protect):
        """Delete the oldest checkpoint files so that at most ``keep`` remain.

        ``protect`` (the checkpoint just written) is never deleted. Age is taken
        from the file modification time, with the epoch number as tie-breaker.
        """
        candidates = []
        for filename in os.listdir(CHECKPOINT_FOLDER):
            epoch = self._checkpoint_epoch(filename)
            path = f'{CHECKPOINT_FOLDER}/{filename}'
            if epoch is None or path == protect:
                continue
            candidates.append((os.path.getmtime(path), epoch, path))
        candidates.sort()
        excess = len(candidates) - (keep - 1)
        for _, _, path in candidates[:max(excess, 0)]:
            os.remove(path)
    
    def save_check_interval(self, epoch, interval=50):
        """Save a checkpoint when ``epoch`` is a positive multiple of ``interval``.

        After saving, older checkpoint files are removed so that only the
        ``MAX_CHECKPOINTS`` most recent ones stay on disk.
        """
        if epoch <= 0 or epoch % interval:
            return
        checkpath = self.create_checkpoint_file(epoch)
        self.save_checkpoint(epoch, checkpath)
        self._prune_checkpoints(keep=self.MAX_CHECKPOINTS, protect=checkpath)

    
    def load_model(self):
        """Load the weights stored in ``self.model_file`` and switch the model to eval mode."""
        print('loading model...')
        self.model.load_state_dict(torch.load(self.model_file))
        self.model.eval()
        print('model loaded...')

    def save_model(self):
        """Write the model's ``state_dict`` to ``self.model_file``."""
        self._ensure_parent_dir(self.model_file)
        torch.save(self.model.state_dict(), self.model_file)
        print('model saved...')

    def save_best_model(self, param, acc_param=True):
        """Save the model when ``param`` beats the best value seen so far.

        Args:
            param: metric value of the current epoch.
            acc_param: ``True`` when larger is better (accuracy), ``False`` when
                smaller is better (loss).

        The first value after a fresh param file always counts as an improvement;
        comparing a loss against the ``NO_BEST_PARAM`` placeholder would otherwise
        never succeed.
        """
        if self.param == self.NO_BEST_PARAM:
            improved = True
        elif acc_param:
            improved = param > self.param
        else:
            improved = param < self.param

        if improved:
            self.param = param
            self.write_file(self.param_file, f'{param}')
            self.save_model()

    def check_param_file(self):
        """Return the best metric stored in the param file, creating the file if needed.

        A missing or empty file is (re)written with ``NO_BEST_PARAM``, which is
        also the value returned in that case.
        """
        if os.path.exists(self.param_file):
            stored = self.read_file(self.param_file).strip()
            if stored:
                return float(stored)
        self.write_file(self.param_file, f'{self.NO_BEST_PARAM}')
        return self.NO_BEST_PARAM

class skdi_detector(Utils):
    """Trainer for the skin-disease classifier.

    Depending on ``params.custom_model`` the model is either ``FixCapsNet``
    (capsule output, trained with the model's own ``loss`` method) or a plain CNN
    from ``make_cnn`` (logits, trained with ``CrossEntropyLoss``).

    Args:
        params: configuration object; the attributes read here are ``train_trans``,
            ``test_trans``, ``batch_size``, ``name``, ``metric_param``,
            ``clip_grad``, ``custom_model``, ``lr``, ``epochs`` and, for the plain
            CNN, ``hid_layers``, ``act_fn``, ``max_pool``, ``pool_after_layers``,
            ``batch_norm``, ``conv_layers`` and ``dropout``.
    """
    def __init__(self,params):
        super().__init__()
        self.params = params
        self.dataset = Dataset(train_trans=params.train_trans, test_trans=params.test_trans,
                               train_batch_size=params.batch_size)
        self.loss_fn = CrossEntropyLoss()
        self.name = params.name
        self.model_file = f'{MODEL_FOLDER}/{self.name}_model.pth'
        self.status_file = f'{STATUS_FOLDER}/{self.name}_status.txt'
        self.plot_file = f'{PLOT_FOLDER}/{self.name}_plot.txt'
        self.param_file = f'{PARAM_FOLDER}/{self.name}_param.txt'
        self.config_file = f'{CONFIG_FOLDER}/{self.name}_config.yaml'
        self.param = self.check_param_file()
        self.metric_param = params.metric_param
        # Stored for reference; gradient clipping is currently not applied in train_step.
        self.clip_grad = params.clip_grad
        # True when the tracked metric is an accuracy (larger is better).
        self.acc_param = self.metric_param in ['train_acc', 'val_acc']

    def _num_classes(self):
        """Number of classes found in the training folder."""
        return len(self.dataset.train_ds.classes)
    
    def to_one_hot(self, x, length):
        """Return a float ``(len(x), length)`` CPU tensor with a 1.0 at each label index in ``x``."""
        labels = x.view(-1).long().cpu()
        x_one_hot = torch.zeros(labels.size(0), length)
        x_one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
        return x_one_hot

    def _batch_loss_and_correct(self, data, labels):
        """Run one batch through the model.

        Args:
            data: image batch (moved to the training device here).
            labels: CPU tensor of class indices for the batch.

        Returns:
            ``(loss, correct)``: the loss tensor and the number of correctly
            classified samples in the batch.
        """
        output = self.model(data.to(pu))
        if self.params.custom_model:
            # Capsule model: the predicted class is the capsule with the largest
            # vector magnitude, and the loss comes from the model itself.
            target = self.to_one_hot(labels, self._num_classes()).to(pu)
            v_mag = torch.sqrt(torch.sum(output**2, dim=2, keepdim=True))
            pred = v_mag.detach().max(1, keepdim=True)[1].cpu().squeeze()
            loss = self.model.loss(output, target)
        else:
            # Plain CNN: logits of shape (batch, classes).
            pred = output.detach().argmax(dim=1).cpu()
            loss = self.loss_fn(output, labels.to(pu))
        correct = pred.eq(labels.view_as(pred)).sum().item()
        return loss, correct

    def train_step(self):
        """Train for one pass over the training loader.

        Returns:
            ``(train_loss, train_acc)``: mean batch loss and the fraction of
            correctly classified training samples.
        """
        self.model.train()
        total_loss, total_correct, total_seen = 0.0, 0, 0
        for data, target in self.dataset.train_dl:
            self.optim.zero_grad()
            loss, correct = self._batch_loss_and_correct(data, target)
            loss.backward()
            self.optim.step()

            total_loss += loss.item()
            total_correct += correct
            total_seen += target.size(0)

        train_loss = total_loss / max(len(self.dataset.train_dl), 1)
        # Accuracy is per sample, not per batch.
        train_acc = total_correct / max(total_seen, 1)

        return train_loss, train_acc
    
    def validate_step(self):
        """Evaluate on the validation loader without gradient tracking.

        Returns:
            ``(val_loss, val_acc)``: mean batch loss and the fraction of correctly
            classified validation samples.
        """
        self.model.eval()
        total_loss, total_correct, total_seen = 0.0, 0, 0
        with torch.no_grad():
            for data, target in self.dataset.val_dl:
                loss, correct = self._batch_loss_and_correct(data, target)
                total_loss += loss.item()
                total_correct += correct
                total_seen += target.size(0)
        
        val_loss = total_loss / max(len(self.dataset.val_dl), 1)
        val_acc = total_correct / max(total_seen, 1)
        return val_loss, val_acc
    
    def create_model(self):
        """Build the model and its Adam optimizer, then print a short summary."""
        num_classes = self._num_classes()
        if self.params.custom_model:
            conv_outputs = 128 #128_Feature_map
            num_primary_units = 8
            primary_unit_size = 16 * 6 * 6
            output_unit_size = 16
            self.model = FixCapsNet(conv_inputs= 3,
                                conv_outputs=conv_outputs,
                                num_primary_units=num_primary_units,
                                primary_unit_size=primary_unit_size,
                                output_unit_size=output_unit_size,
                                num_classes=num_classes,
                                init_weights=True,mode="128").to(pu)
        else:
            self.model = make_cnn(dataset=self.dataset.train_ds, hid_layers=self.params.hid_layers, act_fn=self.params.act_fn,
                                max_pool=self.params.max_pool, pooling_after_layers=self.params.pool_after_layers,
                                batch_norm=self.params.batch_norm, conv_layers=self.params.conv_layers, dropout=self.params.dropout).to(pu)
        
        self.optim = Adam(self.model.parameters(), lr=self.params.lr)
        print(f'Model: {self.model}')
        print(f'Number of classes: {self.dataset.train_ds.classes}')
        print(f'Input image size: {self.dataset.train_ds.__getitem__(0)[0][0].shape}')
        print(f'total number of parameters: {sum([p.numel() for p in self.model.parameters()])}')
        
    def plot_images(self, n_imgs):
        """Show ``n_imgs`` randomly chosen training images in one row, titled with their label index."""
        # squeeze=False keeps ``axes`` an array even when n_imgs == 1.
        fig, axes = plt.subplots(1, n_imgs, figsize=(10, 10), squeeze=False)
        for i in range(n_imgs):
            ind = random.randint(0, len(self.dataset.train_ds)-1)
            img, label = self.dataset.train_ds[ind]
            img = img.numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
            axes.flat[i].set_title(label)
            axes.flat[i].imshow(img)
            axes.flat[i].axis('off')

    def train(self):
        """Run the training loop from the epoch given by the status file through ``params.epochs`` inclusive.

        Every epoch trains, validates, appends the four metrics to the plot file,
        saves a checkpoint and keeps the best model according to ``metric_param``.
        """
        epochs = self.params.epochs
        start_epoch = self.check_status_file()
        # The loop includes ``epochs`` itself, so report the real iteration count.
        epoch_range = range(start_epoch, epochs+1)

        print(f'training for {len(epoch_range)} epochs....')
        for ep in tqdm(epoch_range):
            train_loss, train_acc = self.train_step()
            val_loss, val_acc = self.validate_step()

            metric_param = {'train_loss': train_loss, 'train_acc': train_acc,
                            'val_loss': val_loss, 'val_acc': val_acc}
            
            print(f'epochs: {ep}\t{train_loss = :.4f}\t{train_acc = :.4f}\t{val_loss = :.4f}\t{val_acc = :.4f}')
            self.write_plot_data([train_loss, train_acc, val_loss, val_acc])
            self.save_check_interval(epoch=ep, interval=1)
            self.save_best_model(acc_param=self.acc_param, param=metric_param[self.metric_param])
        
        print('Finished Training....')
    



In [8]:
class Params:
    """Hyper-parameters and transforms for one training run (run name: ``model_2``)."""

    @staticmethod
    def _resize_and_normalize():
        """Return a new pipeline: resize to 299x299, convert to tensor, normalize with mean/std 0.5."""
        steps = [
            tf.Resize(size=(299, 299)),
            tf.ToTensor(),
            tf.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5]),
        ]
        return tf.Compose(steps)

    def __init__(self):
        # Run identity and model choice.
        self.name = 'model_2'
        self.custom_model = True

        # Plain-CNN architecture: each conv entry is [out_channels, kernel_size, stride].
        self.conv_layers = [
            [128, 11, 2],
            [128, 3, 2],
            [512, 3, 2],
            [128, 5, 1],
            [128, 3, 1],
            [512, 3, 1],
        ]
        self.max_pool = [2, 2]
        self.pool_after_layers = 3
        self.act_fn = 'relu'
        self.batch_norm = False
        self.dropout = None
        self.hid_layers = [256, 256]

        # Optimisation settings.
        self.lr = 0.00001
        self.epochs = 100
        self.clip_grad = 0.5
        self.metric_param = 'val_acc'
        # NOTE(review): the run recorded in this notebook ended with a CUDA
        # out-of-memory error (1.46 GiB allocation) using this batch size.
        self.batch_size = 168

        # Train and test images currently go through the same steps (no augmentation).
        self.test_trans = self._resize_and_normalize()
        self.train_trans = self._resize_and_normalize()

In [9]:
# Free cached GPU memory before a new model is placed on the device.
torch.cuda.empty_cache()
params = Params()

# Builds the datasets/loaders and reads (or creates) the best-metric param file.
agent = skdi_detector(params)

# Instantiates the model and optimizer and prints the model summary shown below.
agent.create_model()
print(f'training dataset size: {len(agent.dataset.train_ds)}')
print(f'validation dataset size: {len(agent.dataset.val_ds)}')

# agent.plot_images(3)






Model: FixCapsNet(
  (Convolution): Sequential(
    (0): Conv2d(3, 128, kernel_size=(31, 31), stride=(2, 2), bias=False)
    (1): ReLU(inplace=True)
    (2): FractionalMaxPool2d()
  )
  (CBAM): Conv_CBAM(
    (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): Hardswish()
    (ca): ChannelAttention(
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (max_pool): AdaptiveMaxPool2d(output_size=1)
      (fc1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu1): ReLU()
      (fc2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (sigmoid): Sigmoid()
    )
    (sa): SpatialAttention(
      (conv1): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (sigmoid): Sigmoid()
    )
  )
  (primary): Primary_Caps(
    (Caps_0): ConvUnit(
      (Cpas): Sequential(
        (0): Conv2d(128, 16, kern

In [11]:
# Start (or resume) training. In the run recorded below this stopped in the first
# epoch with a CUDA OutOfMemoryError while trying to allocate 1.46 GiB.
agent.train()

training for 100 epochs....


  0%|          | 0/101 [00:06<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.46 GiB (GPU 0; 5.62 GiB total capacity; 355.67 MiB already allocated; 491.62 MiB free; 388.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF