# 0. Set Environment

In [None]:
# Mount Google Drive at /content/drive so the project folder used by the
# following cells is reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install PyTorch Lightning into the notebook runtime.
# NOTE(review): the version is unpinned, while this notebook relies on older
# Lightning APIs (`pytorch_lightning.logging`, `early_stop_callback`,
# `log_save_interval`, `ModelCheckpoint(filepath=...)`) — consider pinning a
# compatible release; TODO confirm which one.
!pip install pytorch-lightning

In [None]:
# Switch to the project source folder: later cells import `lightning_rexnetv1`
# and load weights via the relative path './rexnet_pretrained/...'.
cd /content/drive/My Drive/Projects/full_bodyshot_classification/src-pytorch

# 1. Build Model

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.logging import TensorBoardLogger
from lightning_rexnetv1 import *

In [None]:
# Names with a leading underscore are skipped by `from ... import *`, so the
# conv and conv+Swish helper functions are defined here explicitly.

def _add_conv(out, in_channels, channels, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False):
    """Append a conv -> batch-norm (-> activation) group to the list ``out``.

    The convolution is created without a bias term. When ``active`` is true an
    in-place activation is appended as well: ``ReLU6`` if ``relu6`` is set,
    plain ``ReLU`` otherwise. ``out`` is modified in place; nothing is returned.
    """
    conv = nn.Conv2d(in_channels, channels, kernel, stride, pad,
                     groups=num_group, bias=False)
    out.extend([conv, nn.BatchNorm2d(channels)])

    if not active:
        return

    activation_cls = nn.ReLU6 if relu6 else nn.ReLU
    out.append(activation_cls(inplace=True))


def _add_conv_swish(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):
    """Append a bias-free conv -> batch-norm -> ``Swish`` group to the list ``out``.

    ``out`` is modified in place; nothing is returned.
    """
    conv = nn.Conv2d(in_channels, channels, kernel, stride, pad,
                     groups=num_group, bias=False)
    out.extend([conv, nn.BatchNorm2d(channels), Swish()])

In [None]:
# Define a callback that prints a message when training starts and ends
class PrintCallback(pl.Callback):
    """Lightning callback that prints a banner at the start and end of training."""

    # Messages are kept as class attributes so the hooks stay one-liners.
    _START_MESSAGE = '*** Training starts...'
    _END_MESSAGE = '*** Training is done.'

    def on_train_start(self, trainer, pl_module):
        """Print the start banner; ``trainer`` and ``pl_module`` are not used."""
        print(self._START_MESSAGE)

    def on_train_end(self, trainer, pl_module):
        """Print the completion banner; ``trainer`` and ``pl_module`` are not used."""
        print(self._END_MESSAGE)

In [None]:
# Set a model
# Can scale model by 'width'
# Issue : 1.5x 
class CustomReXNetV1(pl.LightningModule):
    """ReXNetV1 image classifier packaged as a PyTorch Lightning module.

    The module runs in one of two modes, selected by ``hparams['pretrain']``:

    * **pretrain** - a ``ReXNetV1`` backbone is built with
      ``width_mult=hparams['mult']``, its weights are loaded from
      ``./rexnet_pretrained/rexnetv1_{mult}x.pth``, and the backbone output is
      mapped to ``hparams['num_classes']`` logits by the linear head
      ``self.fc``.
    * **otherwise** - the ReXNet stack assembled in ``__init__``
      (``self.features`` followed by ``self.output``) is used and produces
      ``classes`` logits.

    ``self.features``, ``self.output`` and ``self.fc`` are always constructed,
    even in the mode where ``forward`` does not use them.

    Args:
        hparams: Mapping with the keys ``path`` (dataset root), ``lr``,
            ``batch_size``, ``num_classes``, ``mult`` (scale of the pretrained
            backbone) and ``pretrain`` (``'True'``/``'False'`` string, or a
            bool).
        input_ch: Base output channel count of the first bottleneck block.
        final_ch: Total channel growth, spread evenly over the bottleneck
            blocks after the first one.
        width_mult: Channel multiplier for the stack built in ``__init__``.
        depth_mult: Multiplier applied to the per-stage block counts.
        classes: Number of outputs of the ``self.output`` head.
        use_se: If true, squeeze-excitation is requested for every block from
            the third stage onwards.
        se_ratio: Forwarded to each ``LinearBottleneck``.
        dropout_ratio: Dropout probability in front of the final 1x1 conv.
        bn_momentum: Accepted for signature compatibility; not used here.
    """

    def __init__(self, hparams, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000,
                 use_se=True,
                 se_ratio=12,
                 dropout_ratio=0.5,
                 bn_momentum=0.9):
        super(CustomReXNetV1, self).__init__()

        self.hparams = hparams
        self.path = hparams['path']
        self.lr = hparams['lr']
        self.batch_size = hparams['batch_size']
        self.num_classes = hparams['num_classes']
        # NOTE: these attributes only scale the pretrained backbone; the stack
        # assembled further down uses the `width_mult` / `depth_mult`
        # constructor arguments instead.
        self.width_mult = hparams['mult']
        self.depth_mult = hparams['mult']

        # Accept the documented 'True'/'False' strings as well as real booleans
        # (calling .lower() directly on a bool would raise AttributeError).
        self.pretrain = str(hparams['pretrain']).lower() == 'true'

        if self.pretrain:
            self.model = ReXNetV1(width_mult=self.width_mult)
            weights_path = './rexnet_pretrained/rexnetv1_{}x.pth'.format(str(hparams['mult']))
            # map_location='cpu' lets the checkpoint load on machines without a
            # GPU; Lightning moves the module to the target device afterwards.
            self.model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        self.save_hyperparameters()

        # Per-stage block counts and the stride of the first block of each stage.
        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        layers = [ceil(element * depth_mult) for element in layers]
        # Expand to one stride per block: only the first block of a stage strides.
        strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
        # Expansion ratio per block: 1 for the first stage, 6 for all others.
        ts = [1] * layers[0] + [6] * sum(layers[1:])
        self.depth = sum(layers[:]) * 3
        num_blocks = self.depth // 3

        # For width_mult < 1 the base widths are pre-divided so that the stem
        # and first block keep their unscaled channel counts after multiplying.
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []

        _add_conv_swish(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)

        # The following channel configuration is a simple instance to make each
        # layer become an expand layer: widths grow linearly by
        # final_ch / num_blocks per block.
        for i in range(num_blocks):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (num_blocks * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        if use_se:
            use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
        else:
            use_ses = [False] * sum(layers[:])

        for in_c, c, t, s, se in zip(in_channels_group, channels_group, ts, strides, use_ses):
            features.append(LinearBottleneck(in_channels=in_c,
                                             channels=c,
                                             t=t,
                                             stride=s,
                                             use_se=se, se_ratio=se_ratio))

        pen_channels = int(1280 * width_mult)
        # The penultimate conv consumes the output width of the last block.
        _add_conv_swish(features, channels_group[-1], pen_channels)

        features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)
        self.output = nn.Sequential(
            nn.Dropout(dropout_ratio),
            nn.Conv2d(pen_channels, classes, 1, bias=True))

        # Head used in pretrain mode. 1000 is presumably the output size of the
        # pretrained ReXNetV1 backbone - TODO confirm against ReXNetV1.
        backbone_out_features = 1000
        self.fc = nn.Sequential(
            nn.Linear(backbone_out_features, self.num_classes)
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        if self.pretrain:
            return self.fc(self.model(x))

        x = self.features(x)
        # flatten(1) instead of squeeze(): squeeze() would also remove the
        # batch dimension for a batch of one and break the loss computation.
        return self.output(x).flatten(1)

    def train_dataloader(self):
        """Build the loader for the 'train' split under ``self.path``."""
        train_dset = get_dataset(self.path, 'train')
        return get_dataloader(train_dset, self.batch_size)

    def val_dataloader(self):
        """Build the loader for the 'valid' split under ``self.path``."""
        val_dset = get_dataset(self.path, 'valid')
        return get_dataloader(val_dset, self.batch_size)

    def test_dataloader(self):
        """Build the loader for the 'test' split under ``self.path``."""
        test_dset = get_dataset(self.path, 'test')
        return get_dataloader(test_dset, self.batch_size)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at learning rate ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Run one batch and return ``(loss, accuracy)``.

        ``loss`` is the cross-entropy loss; ``accuracy`` is the percentage of
        correct predictions in the batch, wrapped in a 0-dim tensor.
        """
        data, target = batch
        y_hat = self(data)
        loss = nn.CrossEntropyLoss()(y_hat, target)

        _, predicted = torch.max(y_hat, 1)
        correct = predicted.eq(target).sum().item()
        accuracy = 100 * (correct / target.size(0))
        return loss, torch.tensor(accuracy)

    @staticmethod
    def _epoch_mean(outputs, key):
        """Average the 0-dim tensors stored under ``key`` across step outputs."""
        return torch.stack([step[key] for step in outputs]).mean()

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for one training batch."""
        loss, accuracy = self._shared_step(batch)
        return {'loss':loss, 'trn_acc':accuracy}

    def training_epoch_end(self, outputs):
        """Print and return the epoch-mean training loss and accuracy."""
        train_loss_mean = self._epoch_mean(outputs, 'loss')
        train_acc_mean = self._epoch_mean(outputs, 'trn_acc')
        log = {'avg_trn_loss':train_loss_mean, 'avg_trn_acc':train_acc_mean}
        print(log)
        return {'log':log, 'trn_loss':train_loss_mean, 'trn_acc':train_acc_mean}

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch."""
        loss, accuracy = self._shared_step(batch)
        return {'val_loss':loss, 'val_acc':accuracy}

    def validation_epoch_end(self, outputs):
        """Print and return the epoch-mean validation loss and accuracy."""
        val_loss_mean = self._epoch_mean(outputs, 'val_loss')
        val_acc_mean = self._epoch_mean(outputs, 'val_acc')
        tensorboard_log = {'avg_val_loss':val_loss_mean, 'avg_val_acc':val_acc_mean}
        print(tensorboard_log)
        return {'val_loss':val_loss_mean, 'val_acc':val_acc_mean, 'log':tensorboard_log}

    def test_step(self, batch, batch_idx):
        """Compute loss and accuracy for one test batch."""
        loss, accuracy = self._shared_step(batch)
        return {'test_loss':loss, 'test_acc':accuracy}

    def test_epoch_end(self, outputs):
        """Print and return the epoch-mean test loss and accuracy."""
        test_loss_mean = self._epoch_mean(outputs, 'test_loss')
        test_acc_mean = self._epoch_mean(outputs, 'test_acc')
        tensorboard_log = {'avg_test_loss':test_loss_mean, 'avg_test_acc':test_acc_mean}
        print(tensorboard_log)
        return {'test_loss':test_loss_mean, 'test_acc':test_acc_mean, 'log':tensorboard_log}

In [None]:
# Hyperparameters (play the role of command-line args in this notebook).
# 'pretrain' is deliberately a string and 'mult' selects the model scale.
hparams = dict(
    gpus=1,
    lr=0.0001,
    batch_size=16,
    epoch=200,
    path='/content/drive/My Drive/Projects/full_bodyshot_classification/src-pytorch/data',
    num_classes=2,
    pretrain='True',
    mult=2.0,
)



In [None]:
# Build the Lightning module from the hyperparameter dict; every other
# constructor argument keeps its default value.
model = CustomReXNetV1(hparams)

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
# Checkpointing: keep only the single best checkpoint, where "best" means the
# lowest monitored 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    filepath='experiments',
    save_top_k=1,  # a count of checkpoints to keep; the former `True` only worked because True == 1
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)

In [None]:
# Train
# Use the `checkpoint_callback` configured earlier in the notebook rather than
# a fresh default-constructed ModelCheckpoint, so its settings (monitor, mode,
# save_top_k, filepath) actually take effect.
trainer = pl.Trainer(
    gpus=hparams['gpus'],
    checkpoint_callback=checkpoint_callback,
    max_epochs=hparams['epoch'],
    log_save_interval=100,
    check_val_every_n_epoch=1,
    early_stop_callback=True,
)
trainer.fit(model)

In [None]:
# Save the trained weights (state_dict only) to a file in the current working directory.
torch.save(model.state_dict(),'RexNetx20_epoch9_dr0.2.pt')