In [2]:
# Mount Google Drive: the datasets (MyDrive/DLC2021, MyDrive/our_data) and the
# experiment folder (MyDrive/TACV_exps) used below are read/written through it.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [3]:
!pip install torchinfo
!pip install pytorch-lightning
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.1-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.8.5.post0-py3-none-any.whl (800 kB)
[K     |████████████████████████████████| 800 kB 33.2 MB/s 
[?25hCollecting lightning-utilities!=0.4.0,>=0.3.0
  Downloading lightning_utilities-0.4.2-py3-none-any.whl (16 kB)
Collecting tensorboardX>=2.2
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 76.3 MB/s 
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
[K     |████████████████████████████████| 512 kB 76.8 MB/s 
Installing collected packages: torchmetrics, tensorboardX, light

In [4]:
!cp -r 'gdrive/MyDrive/DLC2021' 'DLC2021'
!cp -r 'gdrive/MyDrive/our_data' 'our_data'


In [5]:
!unzip 'DLC2021/crops.zip' -d './DLC2021/'
!unzip 'our_data/our_data.zip' -d './our_data/'
A


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: ./our_data/crops/6865.png  
  inflating: ./our_data/crops/6885.png  
  inflating: ./our_data/crops/6887.png  
  inflating: ./our_data/crops/6890.png  
  inflating: ./our_data/crops/6912.png  
  inflating: ./our_data/crops/6913.png  
  inflating: ./our_data/crops/6914.png  
  inflating: ./our_data/crops/6917.png  
  inflating: ./our_data/crops/6919.png  
  inflating: ./our_data/crops/6921.png  
  inflating: ./our_data/crops/6922.png  
  inflating: ./our_data/crops/6924.png  
  inflating: ./our_data/crops/6929.png  
  inflating: ./our_data/crops/6931.png  
  inflating: ./our_data/crops/6941.png  
  inflating: ./our_data/crops/6943.png  
  inflating: ./our_data/crops/6945.png  
  inflating: ./our_data/crops/6946.png  
  inflating: ./our_data/crops/6947.png  
  inflating: ./our_data/crops/6956.png  
  inflating: ./our_data/crops/6966.png  
  inflating: ./our_data/crops/6967.png  
  inflating: ./our_data/crops/697

In [1]:
# Show the GPU assigned to this runtime (the Trainer below is configured with accelerator='gpu').
!nvidia-smi

Sun Dec 18 13:34:14 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    33W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
import json
import os
from os import listdir, readlink
from os.path import join

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

def get_parameters(models):
    r"""
    Gather the trainable parameters (``requires_grad=True``) of a model or of an
    arbitrarily nested list / dict of models.

    Lists are visited in order and dicts in value order; any other object is
    treated as a single model exposing ``.parameters()``.

    Returns:
        A flat list with the selected parameters, in visiting order.
    """
    if isinstance(models, (list, dict)):
        children = models.values() if isinstance(models, dict) else models
        collected = []
        for child in children:
            collected.extend(get_parameters(child))
        return collected

    # leaf: a single pytorch model
    return [param for param in models.parameters() if param.requires_grad]


class DLC2021_FDA(Dataset):
    '''
    Source-domain dataset that applies Fourier Domain Adaptation (FDA) on the fly.

    Each source sample is paired with a target image drawn uniformly at random from the
    target split. ``apply_fda`` transfers the target's amplitude spectrum onto the source
    (window controlled by ``beta_fda``) and the result is min-max rescaled per channel
    to [0, 255].

    Both roots must contain ``annots.json`` (records with ``'label'`` and ``'id'``),
    ``<split>.json`` (list of keys into the annotations) and ``crops/<id>.png`` images.
    Samples are ``(img, label, img_id)`` describing the SOURCE image.

    ``image_size`` is stored but images are not resized here.
    '''

    def __init__(self, root_source, root_target, split_source, split_target, beta_fda=(0,0.01), image_size=(224,224)):
        super(DLC2021_FDA, self).__init__()

        self.tgt_root = root_target
        self.src_root = root_source
        self.tgt_split = split_target
        self.src_split = split_source
        self.image_size = image_size
        # (min_beta, max_beta) forwarded to apply_fda
        self.beta_fda = beta_fda

        # annotations and split ids of the source domain
        with open(join(self.src_root, 'annots.json')) as f:
            self.src_data = json.load(f)

        with open(join(self.src_root, f'{self.src_split}.json')) as f:
            self.src_ids = json.load(f)

        # annotations and split ids of the target domain (only images are used)
        with open(join(self.tgt_root, 'annots.json')) as f:
            self.tgt_data = json.load(f)

        with open(join(self.tgt_root, f'{self.tgt_split}.json')) as f:
            self.tgt_ids = json.load(f)

    @staticmethod
    def _read_chw(path):
        """Read an image with OpenCV and return it as a float (C, H, W) tensor of raw values."""
        img = cv2.imread(path)
        # cv2.imread returns None instead of raising on a missing/unreadable file
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        return torch.tensor(np.asarray(img).transpose(2,0,1), dtype=torch.float)

    @staticmethod
    def _rescale_to_255(img, eps=1e-8):
        """Min-max rescale each channel (last two dims are spatial) to [0, 255]."""
        min_b = torch.amin(img, dim=(-2, -1), keepdim=True)
        max_b = torch.amax(img, dim=(-2, -1), keepdim=True)
        # clamp the range so a constant channel maps to 0 instead of NaN
        return 255 * (img - min_b) / torch.clamp(max_b - min_b, min=eps)

    def __getitem__(self, index):
        """Return ``(fda_img, label, img_id)`` for the ``index``-th source sample."""
        # source image: ids from the split file are string keys into annots.json
        key = str(self.src_ids[index])
        img_label = self.src_data[key]['label']
        img_id = self.src_data[key]['id']
        img = self._read_chw(join(self.src_root, 'crops', f'{img_id}.png'))

        # random target image. torch's RNG is re-seeded for every DataLoader worker,
        # whereas NumPy's global RNG state is duplicated across forked workers/epochs
        tgt_index = int(torch.randint(len(self.tgt_ids), (1,)).item())
        tgt_id = str(self.tgt_ids[tgt_index])
        tgt_img = self._read_chw(join(self.tgt_root, 'crops', f'{tgt_id}.png'))

        # FDA output is unbounded: bring it back to the 0-255 range of the raw images
        mod_img = apply_fda(img, tgt_img, self.beta_fda)
        img = self._rescale_to_255(mod_img)

        #debug_augs(img, mod_img, tgt_img, img_id, tgt_id)

        return img, float(img_label), img_id

    def __len__(self):
        return len(self.src_ids)


def debug_augs(src, mod_src, tgt, src_id, tgt_id):
    """
    Write the source, FDA-modified source and target images to the current working
    directory as ``src_<src_id>.png``, ``mod_src_<src_id>.png`` and ``tgt_<tgt_id>.png``.

    Each tensor is assumed to be channel-first (C, H, W) on the CPU; it is converted to
    an (H, W, C) NumPy array before being written with OpenCV.
    """
    # convert everything first, then write, so a failing conversion writes nothing
    arrays = [image.numpy().transpose(1,2,0) for image in (src, mod_src, tgt)]
    filenames = [f'src_{src_id}.png', f'mod_src_{src_id}.png', f'tgt_{tgt_id}.png']

    for filename, array in zip(filenames, arrays):
        cv2.imwrite(filename, array)



class DLC2021(Dataset):
    '''
    Dataset of document crops with binary labels.

    ``root`` must contain ``annots.json`` (records with at least ``'label'`` and ``'id'``),
    ``<split>.json`` (list of keys into the annotations) and ``crops/<id>.png`` images.

    Each sample is ``(img, label, img_id)``: ``img`` is a float (C, H, W) tensor holding
    the raw 0-255 values returned by ``cv2.imread`` (OpenCV channel order, no
    normalization, no resizing), ``label`` is a Python float and ``img_id`` is the
    record's ``'id'``.
    '''

    def __init__(self, root, split, image_size=(224,224)):
        super(DLC2021, self).__init__()

        self.root = root
        self.split = split
        # stored as a list; NOTE: images are not resized by this dataset
        self.image_size = list(image_size)

        # read annotations and the ids belonging to this split
        with open(join(root, 'annots.json')) as f:
            self.data = json.load(f)

        with open(join(root, f'{split}.json')) as f:
            self.ids = json.load(f)

    def __getitem__(self, index):
        """Return ``(img, label, img_id)`` for the ``index``-th id of the split."""
        # ids from the split file are used as string keys into annots.json
        key = str(self.ids[index])
        img_label = self.data[key]['label']
        img_id = self.data[key]['id']
        path = join(self.root, 'crops', f'{img_id}.png')

        # cv2.imread returns None instead of raising on a missing/unreadable file
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')

        # HWC -> CHW float tensor
        img = torch.tensor(np.asarray(img).transpose(2,0,1), dtype=torch.float)

        return img, float(img_label), img_id

    def __len__(self):
        return len(self.ids)


In [7]:
import math

def _centered_slice(size, beta):
    """Slice of half-width floor(size * beta) around the zero frequency of an fftshift-ed axis, clipped to [0, size]."""
    half = int(math.floor(size * beta))
    # after fftshift the zero frequency sits at index size // 2 (even and odd sizes)
    center = size // 2
    return slice(max(center - half, 0), min(center + half + 1, size))


def apply_fda(source_img, target_img, betas):
    """
    Fourier Domain Adaptation: give ``source_img`` the low-frequency amplitude spectrum
    of ``target_img`` while keeping the source phase (i.e. its content).

    Args:
        source_img: real tensor whose last two dims are spatial (H, W), e.g. (C, H, W);
            any leading dims are treated as batch/channel dims.
        target_img: real tensor with the same shape as ``source_img`` (not checked;
            the swap window is computed from the source shape).
        betas: ``(min_b, max_b)``. Amplitudes inside the centred window of half-size
            ``floor(H * max_b)`` x ``floor(W * max_b)`` are taken from the target. If
            ``min_b > 0``, the inner window of half-size ``floor(H * min_b)`` x
            ``floor(W * min_b)`` is restored from the source, so only the band between
            the two windows is swapped. Windows are clipped to the image.

    Returns:
        Real tensor with the shape of ``source_img``. Values are not clipped or rescaled.
    """
    min_b, max_b = betas

    # complex 2D spectra over the spatial dims (fft2 does not modify its input)
    source_f = torch.fft.fft2(source_img, dim=(-2, -1))
    target_f = torch.fft.fft2(target_img, dim=(-2, -1))

    # amplitude / phase decomposition; the target phase is not needed
    source_amp, source_phase = torch.abs(source_f), torch.angle(source_f)
    target_amp = torch.abs(target_f)

    # move the zero frequency to the centre so the low-frequency window is a centred box
    source_shifted_amp = torch.fft.fftshift(source_amp, dim=(-2, -1))
    target_shifted_amp = torch.fft.fftshift(target_amp, dim=(-2, -1))

    # untouched source amplitudes, only needed to restore the inner band
    original_shifted_amp = source_shifted_amp.clone() if min_b > 0 else None

    height, width = source_shifted_amp.shape[-2], source_shifted_amp.shape[-1]

    # swap the low-frequency window with the target's
    rows, cols = _centered_slice(height, max_b), _centered_slice(width, max_b)
    source_shifted_amp[..., rows, cols] = target_shifted_amp[..., rows, cols]

    # frequency-band setting: put back the source's lowest frequencies
    if min_b > 0:
        rows, cols = _centered_slice(height, min_b), _centered_slice(width, min_b)
        source_shifted_amp[..., rows, cols] = original_shifted_amp[..., rows, cols]

    # undo the shift and recombine the new amplitude with the original source phase
    source_new_amp = torch.fft.ifftshift(source_shifted_amp, dim=(-2, -1))
    source_new_f = source_new_amp * torch.exp(1j * source_phase)

    # back to the spatial domain; the imaginary part is numerical noise
    source_new_img = torch.fft.ifft2(source_new_f, dim=(-2, -1))
    return torch.real(source_new_img)

In [8]:
# Root folder (relative to the working directory, inside the mounted Drive) for experiment
# outputs: config.json, log_train.txt, checkpoints and prediction files are written under
# EXP_ROOT/<exp_name>.
EXP_ROOT = 'gdrive/MyDrive/TACV_exps'

In [9]:
import math
import json
import torchinfo
from datetime import datetime
import torch.nn as nn
import logging
import subprocess
import torchvision.models as models
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger

class DocRecDetection(LightningModule):

    """
    PyTorch Lightning system for binary classification of document crops.

    Wraps an ImageNet-pretrained torchvision backbone whose last layer is replaced by a
    single logit, trains it with BCE on sigmoid outputs, and writes per-sample test
    predictions to ``<EXP_ROOT>/<exp_name>/preds_<test_split>.txt``.

    ``config`` is a plain dict; the keys read here are ``exp_name``, ``backbone``,
    ``save_freq``, ``optimizer``, ``lr``, ``w_decay``, ``scheduler``, ``n_epochs``,
    ``bs``, ``n_workers``, ``apply_fda``, ``beta_fda``, ``train_data_root``,
    ``test_data_root``, ``train_split``, ``valid_split`` and ``test_split``.
    """

    def __init__(self, config):
        """
        Store the config and build the model, loss, experiment folder and logger.

        Args:
            config: experiment configuration dict (see the class docstring).
        """
        super().__init__()

        self.args = config
        self.init_pipeline()

        # turns the single logit into the probability expected by BCELoss
        self.act = torch.nn.Sigmoid()

    def init_pipeline(self):
        """
        Build backbone and loss, create the experiment folder, save the config and
        set up the text logger.

        Raises:
            RuntimeError: if ``config['backbone']`` is not 'mobilenet',
                'efficientnet' or 'resnet50'.
        """
        exp_root = EXP_ROOT + '/{}'.format(self.args['exp_name'])

        # ImageNet-pretrained backbone whose final layer is replaced by a 1-output linear layer
        if self.args['backbone'] == 'mobilenet':
            backbone = models.mobilenet_v3_small(weights="DEFAULT")
            _ = backbone.classifier.pop(-1)
            backbone.classifier.append(nn.Linear(1024,1))

        elif self.args['backbone'] == 'efficientnet':
            backbone = models.efficientnet_v2_s(weights="DEFAULT")
            _ = backbone.classifier.pop(-1)
            backbone.classifier.append(nn.Linear(1280,1))

        elif self.args['backbone'] == 'resnet50':
            backbone = models.resnet50(weights="DEFAULT")
            backbone.fc = nn.Linear(2048,1)

        else:
            raise RuntimeError(f'Model {self.args["backbone"]} not supported.')

        self.backbone = backbone

        print(torchinfo.summary(self.backbone, (16,3,224,224)))
        # BCELoss works on probabilities: the sigmoid is applied in forward_batch
        self.loss_fn = torch.nn.BCELoss()

        # create the experiment folder (with missing parents, no-op if it exists) and save
        # the config this system was actually built with
        os.makedirs(exp_root, exist_ok=True)
        with open(f'{exp_root}/config.json','w') as f:
            json.dump(self.args, f)

        self.std_logger = self.get_logger()

    def get_logger(self):
        """
        Return the 'log' logger, writing both to the console and to
        ``<EXP_ROOT>/<exp_name>/log_train.txt``.

        The file is opened in append mode so that re-creating the system (e.g. through
        ``load_from_checkpoint`` before testing) does not wipe the training log.
        """
        path_log = EXP_ROOT + '/{}/log_train.txt'.format(self.args['exp_name'])
        logger = logging.getLogger('log')
        logger.setLevel(logging.INFO)
        # handlers survive cell re-executions in Colab: close and drop them to avoid
        # duplicated lines and leaked file handles
        for handler in list(logger.handlers):
            handler.close()
            logger.removeHandler(handler)
        formatter = logging.Formatter('%(asctime)s - %(message)s')

        # file handler
        fh = logging.FileHandler(path_log, mode='a')
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        # console handler
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        # keep messages out of the root logger (avoids printing them twice)
        logger.propagate = False

        return logger

    def get_callbacks(self):
        """
        Return the Trainer callbacks: a checkpoint saved every ``save_freq`` epochs
        (all kept, named ``epoch=XXXX.ckpt``) and a tqdm progress bar.
        """
        args = self.args
        cpt_callback = ModelCheckpoint(
            dirpath=join(EXP_ROOT,args['exp_name']),
            every_n_epochs=args['save_freq'],
            save_top_k=-1,
            filename='{epoch:04d}'
        )

        bar_callback = TQDMProgressBar(refresh_rate=10)

        return [cpt_callback, bar_callback]

    def get_wb_logger(self):
        """Create the Weights & Biases logger (project 'TACV'), keep a reference and return it in a list."""
        wb_logger = WandbLogger(
            save_dir= self.args['exp_name'],
            project='TACV',
            name=self.args['exp_name']
        )

        self.wb_logger = wb_logger

        return [wb_logger]

    def configure_optimizers(self):
        """
        Build the optimizer ('sgd' or 'adam' -> AdamW) and the per-epoch LR scheduler
        ('step', 'cosine' or None for a constant learning rate).

        Raises:
            RuntimeError: on an unknown optimizer or scheduler type.
        """
        parameters = get_parameters(self.backbone)

        if self.args['optimizer'] == 'sgd':
            optimizer = torch.optim.SGD(
                params=parameters,
                lr=self.args['lr'],
                weight_decay=self.args['w_decay'],
                nesterov=False)

        elif self.args['optimizer'] == 'adam':
            optimizer = torch.optim.AdamW(
                params=parameters,
                lr=self.args['lr'],
                weight_decay=self.args['w_decay'])
        else:
            raise RuntimeError('Optimizer type {} not implemented!'.format(self.args['optimizer']))

        self.optimizer = optimizer
        self.args['step'] = self.args['n_epochs']

        scheduler_type = self.args['scheduler']
        if scheduler_type == 'step':
            # lr divided by 10 after 50%, 75% and 90% of the epochs
            part_milestones = [math.ceil(self.args['n_epochs'] * step) for step in [0.5,0.75,0.9]]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=part_milestones,
                gamma=0.1)

        elif scheduler_type == 'cosine':
            # T_max must be >= 1, otherwise CosineAnnealingLR divides by zero (n_epochs == 1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=max(self.args['step'] - 1, 1),
                eta_min=0.1 * self.args['lr'])

        elif scheduler_type is None:
            # constant lr: the only milestone lies after the last epoch, so it is never reached
            part_milestones = [math.ceil(self.args['n_epochs']*1.5)]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=part_milestones,
                gamma=0.1)

        else:
            raise RuntimeError('Scheduler type {} not implemented!'.format(scheduler_type))

        return [self.optimizer], [scheduler]

    def structured_log(self, loss, acc, prec, rec, prefix):
        """Log loss/accuracy/precision/recall under '<metric>/<prefix>', per step and per epoch."""
        all_metrics = {}
        all_metrics[f'loss/{prefix}'] = loss
        all_metrics[f'accuracy/{prefix}'] = acc
        all_metrics[f'precision/{prefix}'] = prec
        all_metrics[f'recall/{prefix}'] = rec

        self.log_dict(
            all_metrics,
            on_step=True,
            on_epoch=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
            batch_size=self.args['bs']
        )

    def forward_batch(self, batch):
        """
        Run the model on ``(imgs, labels, ...)`` and return ``(probabilities, float32 labels)``.
        """
        imgs = batch[0]
        labels = batch[1]

        preds = self.forward(imgs)

        labels = labels.to(torch.float32)

        preds = self.act(preds)

        return preds, labels

    ############### ON START FUNCTIONS

    def on_train_start(self):
        """Log the start of training, watch the backbone in W&B and upload the config."""
        self.std_logger.info('Starting training.')
        self.wb_logger.watch(self.backbone, log_freq=100)
        for k, v in self.args.items():
            self.wb_logger.experiment.config.update({k:v})

        return super().on_train_start()

    def on_test_start(self):
        """Create (truncate) the prediction file and write its CSV header."""
        self.pred_file = EXP_ROOT + '/{}/preds_{}.txt'.format(self.args['exp_name'],self.args['test_split'])

        with open(self.pred_file,'w') as f:
            f.write("id,gt,pred\n")

        self.std_logger.info('Starting testing.')

        return super().on_test_start()

    ############## STEP FUNCTIONS

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE loss and batch metrics; return the loss."""
        pred, labels = self.forward_batch(batch)

        loss = self.loss_fn(pred, labels)

        acc, prec, rec = compute_metrics(pred, labels)

        self.structured_log(loss, acc, prec, rec, prefix='train')

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and batch metrics."""
        pred, labels = self.forward_batch(batch)

        loss = self.loss_fn(pred, labels)

        acc, prec, rec = compute_metrics(pred, labels)

        self.structured_log(loss, acc, prec, rec, prefix='valid')

        return loss

    def test_step(self, batch, batch_idx):
        """Append one 'id,gt,pred' line per sample (prediction thresholded at 0.5)."""
        with open(self.pred_file,'a') as f:

            idxs = batch[2]
            preds, labels = self.forward_batch(batch)

            preds = torch.where(preds>0.5, 1, 0)

            for idx_i, pred_i, label_i in zip(idxs, preds, labels):

                idx_i = int(idx_i)
                pred_i = int(pred_i.item())
                label_i = int(label_i.item())
                f.write(f'{idx_i},{label_i},{pred_i}\n')

    ############# END FUNCTIONS

    def on_test_end(self):
        """Read back the prediction file and log recall, precision and accuracy (in %)."""
        fp,fn,tp,tn = 0.,0.,0.,0.

        with open(self.pred_file, 'r') as f:
            lines = f.readlines()

        # skip the header row and any blank line
        for line in lines[1:]:
            if not line.strip():
                continue

            _,gt,pred = line.split(',')

            gt, pred = int(gt), int(pred)
            if gt == 0:
                if pred == 0:
                    tn += 1
                else:
                    fp += 1
            else:
                if pred == 0:
                    fn += 1
                else:
                    tp += 1

        # guard against empty denominators (e.g. no positives in the split)
        recall = 100 * tp / (tp + fn) if (tp + fn) > 0 else 0.
        precision = 100 * tp / (tp + fp) if (tp + fp) > 0 else 0.
        total = tp + tn + fp + fn
        accuracy = 100 * (tp + tn) / total if total > 0 else 0.

        self.std_logger.info("Recall: {:2.2f}".format(float(recall)))
        self.std_logger.info("Precision: {:2.2f}".format(float(precision)))
        self.std_logger.info("Accuracy: {:2.2f}".format(float(accuracy)))

        return super().on_test_end()

    ############# DATALOADERS

    def get_train_dataloader(self):
        """
        Shuffled training loader on ``train_data_root``/``train_split``. With
        ``apply_fda`` each sample is FDA-stylized with a random image from the
        ``train_split`` of ``test_data_root`` (the target domain).
        """
        args = self.args

        if args['apply_fda']:

            # train split of DLC as source, train split of our dataset as target
            dataset = DLC2021_FDA(
                root_source=args['train_data_root'],
                root_target=args['test_data_root'],
                split_source=args['train_split'],
                split_target=args['train_split'],
                beta_fda=args['beta_fda']
            )

        else:
            dataset = DLC2021(
                root=args['train_data_root'],
                split=args['train_split']
            )

        self.std_logger.info('TRAIN samples: {}'.format(dataset.__len__()))

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args['bs'],
            num_workers=args['n_workers'],
            shuffle=True,
            drop_last=True
        )

        return dataloader

    def get_valid_dataloader(self):
        """Sequential loader on ``train_data_root``/``valid_split``."""
        args = self.args

        dataset = DLC2021(
            root=args['train_data_root'],
            split=args['valid_split']
        )

        self.std_logger.info('VALID samples: {}'.format(dataset.__len__()))

        dataloader = DataLoader(
            dataset=dataset,
            num_workers=args['n_workers'],
            batch_size=args['bs'],
            shuffle=False,
            drop_last=False
        )

        return dataloader

    def get_test_dataloader(self):
        """Sequential loader on ``train_data_root``/``test_split`` (source-domain test set)."""
        args = self.args

        dataset = DLC2021(
            root=args['train_data_root'],
            split=args['test_split']
        )

        self.std_logger.info('TEST samples: {}'.format(dataset.__len__()))

        dataloader = DataLoader(
            dataset=dataset,
            num_workers=args['n_workers'],
            batch_size=args['bs'],
            shuffle=False,
            drop_last=False
        )

        return dataloader

    def get_test_custom_dataloader(self):
        """Sequential loader on ``test_data_root``/``test_split`` (target-domain test set)."""
        args = self.args

        dataset = DLC2021(
            root=args['test_data_root'],
            split=args['test_split']
        )

        self.std_logger.info('TEST samples: {}'.format(dataset.__len__()))

        dataloader = DataLoader(
            dataset=dataset,
            num_workers=args['n_workers'],
            batch_size=args['bs'],
            shuffle=False,
            drop_last=False
        )

        return dataloader

    def forward(self, x):
        """Return one logit per image, shape (B,)."""
        feats = self.backbone(x)
        # drop only the trailing singleton dim: squeeze() would also collapse a batch of 1
        feats = feats.squeeze(-1)
        return feats


In [10]:
import torch

def compute_metrics(preds, labels):
    """
    Batch accuracy, precision and recall for binary predictions.

    Args:
        preds: tensor of predicted probabilities; a sample is predicted positive when
            its value is > 0.5. 0-d tensors (a squeezed batch of one) are accepted.
        labels: tensor with the same number of elements; each value is truncated to an
            integer and any non-zero value counts as positive.

    Returns:
        ``(accuracy, precision, recall)`` as fractions in [0, 1]; a metric whose
        denominator is empty is returned as ``0``.
    """
    # vectorized counting: a handful of reductions instead of one host sync per element
    pred_pos = (torch.as_tensor(preds) > 0.5).reshape(-1)
    gt_pos = (torch.as_tensor(labels).to(torch.long) != 0).reshape(-1)

    tp = float((pred_pos & gt_pos).sum().item())
    fp = float((pred_pos & ~gt_pos).sum().item())
    fn = float((~pred_pos & gt_pos).sum().item())
    tn = float((~pred_pos & ~gt_pos).sum().item())

    recall = tp / (tp + fn) if (tp+fn) > 0 else 0
    precision = tp / (tp + fp) if (tp+fp) > 0 else 0
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp+tn+fp+fn) > 0 else 0

    return accuracy, precision, recall

    

In [15]:
# Standard training / evaluation run.
# With RUN_TRAINING = False the cell only evaluates the last saved checkpoint on the
# target-domain test split; set it to True to train (optionally resuming from RESUME_EPOCH).
from pytorch_lightning import Trainer, seed_everything

CONFIG = {
    # exp configs
    'exp_name': 'mobilenet_fda_low_1',
    'save_freq' : 5,
    'n_epochs': 10,
    'valid_freq': 1,
    'backbone': 'mobilenet',
    'seed': 42,
    # training configs
    'bs': 32,
    'n_workers': 0,
    'lr': 1e-4,
    'w_decay': 4e-5,
    'optimizer': 'adam',
    'scheduler': None,
    # data configs
    'train_split': 'train',
    'valid_split': 'test',
    'test_split': 'test',
    'train_data_root': 'DLC2021',
    'test_data_root': 'our_data',
    'img_size': (224,224),
    'apply_fda': True,
    'beta_fda': (0.02,0.05),
}

RUN_TRAINING = False
RESUME_EPOCH = None  # e.g. 4 to resume from '<EXP_ROOT>/<exp_name>/epoch=0004.ckpt'

args = CONFIG

# FDA pairs each source image with a random target image: seed everything (incl. workers)
seed_everything(args['seed'], workers=True)

system = DocRecDetection(args)

trainer = Trainer(
    logger = system.get_wb_logger(),
    enable_checkpointing=True,
    callbacks=system.get_callbacks(),
    accelerator='gpu',
    log_every_n_steps=1,
    auto_select_gpus=True,
    num_sanity_val_steps=2,
    check_val_every_n_epoch=args['valid_freq'],
    max_epochs=args['n_epochs']
)

train_data = system.get_train_dataloader()
# NOTE(review): validation monitors the same target-domain split used by trainer.test
# below; do not use it for model selection.
valid_data = system.get_test_custom_dataloader()
test_data = system.get_test_custom_dataloader()

if RUN_TRAINING:
    resume_ckpt = None
    if RESUME_EPOCH is not None:
        resume_ckpt = EXP_ROOT + '/{}/epoch={:04d}.ckpt'.format(args['exp_name'], RESUME_EPOCH)
    trainer.fit(
        system,
        train_dataloaders=train_data,
        val_dataloaders=valid_data,
        ckpt_path=resume_ckpt
    )

# evaluate the checkpoint written by ModelCheckpoint after the last epoch (0-based index)
last_epoch = args['n_epochs'] - 1
ckpt_path = EXP_ROOT + '/{}/epoch={:04d}.ckpt'.format(args['exp_name'],last_epoch)
trained_model = DocRecDetection.load_from_checkpoint(ckpt_path, config=args)

trainer.test(model=trained_model, dataloaders=test_data)



  rank_zero_warn(
INFO:pytorch_lightning.trainer.connectors.accelerator_connector:Auto select gpus: [0]


Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [16, 1]                   --
├─Sequential: 1-1                                  [16, 576, 7, 7]           --
│    └─Conv2dNormActivation: 2-1                   [16, 16, 112, 112]        --
│    │    └─Conv2d: 3-1                            [16, 16, 112, 112]        432
│    │    └─BatchNorm2d: 3-2                       [16, 16, 112, 112]        32
│    │    └─Hardswish: 3-3                         [16, 16, 112, 112]        --
│    └─InvertedResidual: 2-2                       [16, 16, 56, 56]          --
│    │    └─Sequential: 3-4                        [16, 16, 56, 56]          744
│    └─InvertedResidual: 2-3                       [16, 24, 28, 28]          --
│    │    └─Sequential: 3-5                        [16, 24, 28, 28]          3,864
│    └─InvertedResidual: 2-4                       [16, 24, 28, 28]          --
│    │    └─Sequential: 3-6   

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
2022-12-18 15:09:01,202 - TRAIN samples: 31931
2022-12-18 15:09:01,232 - TEST samples: 3928
2022-12-18 15:09:01,271 - TEST samples: 3928
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [16, 1]                   --
├─Sequential: 1-1                                  [16, 576, 7, 7]           --
│    └─Conv2dNormActivation: 2-1                   [16, 16, 112, 112]        --
│    │    └─Conv2d: 3-1                            [16, 16, 112, 112]        432
│    │    └─BatchNorm2d: 3-2                       [16, 16, 112, 112]        32
│    │    └─Hardswish: 3-3                         [16, 16, 112, 112]        --
│    └─InvertedResidual: 2-2                       [16, 16, 56, 56]          --
│    │    └─Sequential: 3-4                        [16, 16, 56, 56]          744
│    └─InvertedResidual: 2-3                       [16, 24, 28, 28]          --
│    │    └─Sequential: 3-5                        [16, 24, 28, 28]          3,864
│    └─InvertedResidual: 2-4                       [16, 24, 28, 28]          --
│    │    └─Sequential: 3-6   

Testing: 0it [00:00, ?it/s]

2022-12-18 15:09:01,761 - Starting testing.
2022-12-18 15:09:19,442 - Recall: 31.91
2022-12-18 15:09:19,446 - Precision: 78.03
2022-12-18 15:09:19,450 - Accuracy: 60.46


[{}]