In [None]:
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import config
import pytorch_lightning as pl
import torchmetrics
import torchvision.models as models
import ray
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from math import ceil

# Seed the driver-side RNGs. NumPy is seeded as well because sklearn's
# train_test_split falls back to NumPy's global RNG when no random_state is given.
# NOTE(review): Ray Tune trials run in their own worker processes, so these seeds
# presumably do not carry over to code executed inside a trial.
torch.manual_seed(42)
np.random.seed(42)

# ignore_reinit_error makes this cell safe to re-run in a live kernel
# (a second plain ray.init() call raises instead).
ray.init(num_cpus=config.NUM_CPUS, num_gpus=config.NUM_GPUS, log_to_driver=False, ignore_reinit_error=True)
# %config InlineBackend.figure_formats = ['svg'] # Make plots high resolution
plt.rcParams['figure.figsize'] = (12, 12) # Increase figure size

In [None]:
class DataAugmentation(nn.Module):
    """Optional image augmentation: colour jitter followed by RandAugment.

    Args:
        apply_color_jitter: run ``ColorJitter(0.5, 0.5, 0.5, 0.5)`` first when True.
        apply_random_augment: run ``RandAugment`` (after the jitter, if any) when True.
        *args, **kwargs: forwarded unchanged to ``transforms.RandAugment``.
    """

    def __init__(self, apply_color_jitter=False, apply_random_augment=True, *args, **kwargs):
        super().__init__()
        # Flags consulted on every forward call.
        self._apply_color_jitter = apply_color_jitter
        self._apply_random_augment = apply_random_augment

        # Both transforms are always constructed; the flags only decide whether they run.
        self.rand_augment = transforms.RandAugment(*args, **kwargs)
        self.jitter = transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)

    @torch.no_grad()
    def forward(self, x):
        """Apply the enabled transforms to ``x`` in order: jitter, then RandAugment."""
        pipeline = (
            (self._apply_color_jitter, self.jitter),
            (self._apply_random_augment, self.rand_augment),
        )
        for enabled, step in pipeline:
            if enabled:
                x = step(x)
        return x

In [None]:
class TrashNetDataModule(pl.LightningDataModule):
    """LightningDataModule serving the TrashNet ``ImageFolder`` dataset.

    The images under ``data_dir`` are split once, stratified by class, into
    train / validation / test index sets. Every dataloader re-opens the folder
    with its own transform and draws from its subset via a ``SubsetRandomSampler``.

    Args:
        transfer_learning: resize images to 224x224 when True, otherwise to
            ``config.IMAGE_SIZE`` x ``config.IMAGE_SIZE``.
        augment: prepend ``DataAugmentation`` to the training transform.
        data_dir: root folder handed to ``datasets.ImageFolder``.
        batch_size, num_workers, pin_memory: forwarded to every ``DataLoader``.
        seed: ``random_state`` for the stratified splits. An int makes the split
            reproducible across runs and processes; ``None`` uses NumPy's global
            RNG (a different split on every run).
    """

    def __init__(self, transfer_learning=False, augment=True, data_dir=config.ROOT_DIR, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, pin_memory=config.PIN_MEMORY, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.augment = augment
        self.seed = seed
        self.image_size = 224 if transfer_learning else config.IMAGE_SIZE
        # Per-channel normalisation statistics, precomputed with get_distribution
        # on the train + validation images (computed using seed = 42).
        self.mean, self.std = (
            0.5389, 0.5123, 0.4846), (0.2201, 0.2178, 0.2323)

        self.augmentation = DataAugmentation()
        # Filled in by setup(); None means the split has not been made yet.
        self.train_sampler = self.val_sampler = self.test_sampler = None

    def get_indices(self, dataset):
        """Return ``(train_idx, val_idx, test_idx)`` index arrays, each stratified by class.

        ``config.TEST_SPLIT`` of all images is held out for testing first, then
        ``config.VAL_SPLIT`` of the remainder becomes the validation set.
        """
        targets = np.asarray(dataset.targets)
        train_val_idx, test_idx = train_test_split(
            np.arange(len(targets)), test_size=config.TEST_SPLIT,
            stratify=targets, random_state=self.seed)
        # Splitting the index array itself keeps original dataset indices.
        train_idx, val_idx = train_test_split(
            train_val_idx, test_size=config.VAL_SPLIT,
            stratify=targets[train_val_idx], random_state=self.seed)
        return train_idx, val_idx, test_idx

    def get_samplers(self, train_idx, val_idx, test_idx):
        """Wrap each index array in a ``SubsetRandomSampler``."""
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)
        test_sampler = SubsetRandomSampler(test_idx)
        return train_sampler, val_sampler, test_sampler

    def setup(self, stage=None):
        """Split the dataset and build the samplers, on the first call only.

        Lightning may call ``setup`` once per stage (fit, validate, test, ...).
        Reusing the first split guarantees that test images never land in the
        training subset of a later stage.
        """
        if self.train_sampler is not None:
            return
        dataset = datasets.ImageFolder(
            self.data_dir, transform=transforms.Compose(self._resize_to_tensor()))

        train_idx, val_idx, test_idx = self.get_indices(dataset)

        # Normalisation stats are hard-coded in __init__. To recompute them from the
        # train + val images only (the test split stays hidden), use:
        # self.mean, self.std = self.get_distribution(
        #     dataset, np.concatenate([train_idx, val_idx]))
        self.train_sampler, self.val_sampler, self.test_sampler = self.get_samplers(
            train_idx, val_idx, test_idx)

    def _resize_to_tensor(self):
        """Transforms shared by every pipeline: square resize, then PIL -> tensor."""
        return [
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ]

    def _normalized_transform(self, prefix=()):
        """Compose ``prefix`` + resize/to-tensor + per-channel normalisation."""
        return transforms.Compose(
            list(prefix) + self._resize_to_tensor()
            + [transforms.Normalize(self.mean, self.std)])

    def _make_loader(self, transform, sampler):
        """Open ``data_dir`` with ``transform`` and wrap it in a DataLoader over ``sampler``."""
        dataset = datasets.ImageFolder(self.data_dir, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, sampler=sampler,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def train_dataloader(self):
        """Training loader; augmentation (if enabled) runs before resizing."""
        prefix = [self.augmentation] if self.augment else []
        return self._make_loader(self._normalized_transform(prefix), self.train_sampler)

    def val_dataloader(self):
        """Validation loader without augmentation."""
        return self._make_loader(self._normalized_transform(), self.val_sampler)

    def test_dataloader(self):
        """Test loader without augmentation."""
        return self._make_loader(self._normalized_transform(), self.test_sampler)

    def get_distribution(self, data, indices):
        """Per-channel mean and population standard deviation over ``data[indices]``.

        Args:
            data: indexable dataset whose items are ``(image_tensor, label)`` pairs.
                All images are assumed to share one ``(C, H, W)`` shape (TODO confirm).
            indices: sample indices to include.

        Returns:
            ``(mean, std)`` as Python floats for a single channel, otherwise as
            1-D tensors with one entry per channel.

        Raises:
            ValueError: if ``indices`` is empty.
        """
        indices = list(indices)
        if not indices:
            raise ValueError('get_distribution needs at least one index')
        shape = data[indices[0]][0].shape
        channels = shape[0]
        # Normalise by the pixels actually summed (the requested subset),
        # not by the size of the whole dataset.
        count = len(indices) * shape[1:].numel()
        print('Calculating Mean...')
        channel_sums = torch.zeros(channels)
        for index in indices:
            channel_sums += torch.sum(data[index][0].flatten(1), 1)
        mean = channel_sums / count
        print('Calculating Std...')
        centre = mean.view(channels, 1)
        squared_sums = torch.zeros_like(mean)
        for index in indices:
            squared_sums += torch.sum(torch.square(data[index][0].flatten(1) - centre), 1)
        std = torch.sqrt(squared_sums / count)
        mean = mean.item() if mean.numel() == 1 else mean
        std = std.item() if std.numel() == 1 else std
        print(f'Mean: {mean}, Std: {std}')
        return mean, std


In [None]:
class TrashBaseClass(pl.LightningModule):
    """Shared LightningModule plumbing for the TrashNet classifiers.

    Saves constructor arguments (e.g. ``lr``) as hyperparameters and provides
    one-hot encoding and metric-logging helpers for the subclasses.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # Subclasses pass `self.accuracy` to `log_metrics` for every stage. The
        # attribute is kept for compatibility, but `log_metrics` swaps in the
        # stage-specific instance below. Otherwise train/val/test updates would
        # accumulate into one shared metric state.
        self.accuracy = torchmetrics.Accuracy()
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def to_one_hot(self, x, ref):
        """One-hot encode integer labels ``x`` into a ``(len(x), NUM_CLASSES)`` tensor.

        The result is cast with ``type_as(ref)`` so it matches ``ref``'s dtype and
        device and can be used directly as a probability target for cross-entropy.
        """
        return F.one_hot(x, num_classes=config.NUM_CLASSES).type_as(ref)

    def log_metrics(self, metric, mode, pred, labels, loss, on_step=True, on_epoch=True):
        """Update an accuracy metric with this batch and log it together with ``loss``.

        Args:
            metric: metric to update. If it is the shared ``self.accuracy``, the
                per-stage ``{mode}_accuracy`` instance is used instead (when it exists).
            mode: stage prefix for the logged keys, e.g. ``'train'``, ``'val'``, ``'test'``.
            pred: raw logits; softmax is applied before updating the metric.
            labels: integer class labels.
            loss: loss value to log as ``{mode}_loss``.

        Returns:
            The dict that was passed to ``log_dict``.
        """
        if metric is self.accuracy:
            metric = getattr(self, f'{mode}_accuracy', metric)
        y_pred = pred.softmax(dim=-1)
        metric(y_pred, labels)
        metrics = {f'{mode}_accuracy': metric, f'{mode}_loss': loss}
        self.log_dict(metrics, on_step=on_step, on_epoch=on_epoch)
        return metrics

In [None]:
class TrashBaseline(TrashBaseClass):
    """Fully connected baseline: flattened image -> 32 -> 64 -> NUM_CLASSES logits.

    The first layer expects ``config.IMAGE_SIZE * config.IMAGE_SIZE *
    config.INPUT_CHANNELS`` input features, so images must be resized to
    ``config.IMAGE_SIZE`` (e.g. a data module built with ``transfer_learning=False``).
    """

    def __init__(self, *args, **kwargs):
        # Module.__init__ must run before any submodule is assigned. The arguments
        # are forwarded so the base class's save_hyperparameters() receives them.
        super().__init__(*args, **kwargs)
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.IMAGE_SIZE * config.IMAGE_SIZE * config.INPUT_CHANNELS, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, config.NUM_CLASSES)
        )

    def forward(self, x):
        """Return ``(batch, NUM_CLASSES)`` logits."""
        return self.net(x)

    def configure_optimizers(self):
        """Adam over all parameters with the saved ``lr`` and ``config.B1``/``config.B2`` betas."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(config.B1, config.B2))

    def _shared_step(self, batch, mode):
        """Run one batch, compute cross-entropy on one-hot targets and log metrics under ``mode``."""
        imgs, labels = batch
        one_hot_labels = self.to_one_hot(labels, imgs)
        # No .squeeze(): the net already yields (batch, classes). Squeezing would
        # drop the batch dimension for a final batch holding a single image.
        pred = self(imgs)
        loss = F.cross_entropy(pred, one_hot_labels)
        metrics = self.log_metrics(self.accuracy, mode, pred, labels, loss)
        return loss, metrics

    def training_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        _, metrics = self._shared_step(batch, 'val')
        return metrics

    def test_step(self, batch, batch_idx):
        _, metrics = self._shared_step(batch, 'test')
        return metrics

In [None]:
# model = TrashBaseline(lr=0.0005)
# dm = TrashNetDataModule()
# trainer = pl.Trainer(max_epochs=config.EPOCHS, gpus=1, log_every_n_steps=20)
# trainer.fit(model, dm)

In [None]:
class TrashResnet50(TrashBaseClass):
    """Transfer-learning classifier: pretrained ResNet-50 features + a trainable linear head.

    The backbone (every ResNet-50 child except the final ``fc``) runs in eval
    mode under ``torch.no_grad()``, so only ``self.classifier`` receives gradients.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        pretrained = models.resnet50(pretrained=True)
        feature_dim = pretrained.fc.in_features
        # Drop the ImageNet classification layer; keep everything before it.
        self.feature_extractor = nn.Sequential(*list(pretrained.children())[:-1])

        self.classifier = nn.Linear(feature_dim, config.NUM_CLASSES)

    def forward(self, x):
        """Return ``(batch, NUM_CLASSES)`` logits computed from frozen backbone features."""
        # eval() keeps the backbone's BatchNorm running statistics from updating.
        self.feature_extractor.eval()
        with torch.no_grad():
            features = torch.flatten(self.feature_extractor(x), 1)
        return self.classifier(features)

    def configure_optimizers(self):
        """Adam with the saved ``lr`` and ``config.B1``/``config.B2`` betas."""
        betas = (config.B1, config.B2)
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr, betas=betas)

    def _shared_step(self, batch, mode):
        """Run one batch, compute cross-entropy on one-hot targets and log metrics under ``mode``."""
        imgs, labels = batch
        targets = self.to_one_hot(labels, imgs)
        logits = self(imgs)
        loss = F.cross_entropy(logits, targets)
        return loss, self.log_metrics(self.accuracy, mode, logits, labels, loss)

    def training_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')[1]

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')[1]


In [None]:
def train_trash_tune(tune_config, checkpoint_dir=config.CHECKPOINT_DIR, transfer_learning=True, num_epochs=config.EPOCHS, num_gpus=config.NUM_GPUS):
    """Train one ``TrashResnet50`` trial with the hyperparameters sampled by Ray Tune.

    Args:
        tune_config: dict providing ``'lr'`` and ``'batch_size'``.
        checkpoint_dir: directory where ``ModelCheckpoint`` writes files; ``None``
            falls back to ``config.CHECKPOINT_DIR``.
        transfer_learning: forwarded to ``TrashNetDataModule`` (224x224 inputs when True).
        num_epochs: maximum number of training epochs.
        num_gpus: GPUs for the trainer; fractional values are rounded up.

    Validation loss and accuracy are reported to Tune as ``loss`` and ``acc``
    at the end of every validation run.
    """
    # NOTE(review): Ray Tune may call a trainable with checkpoint_dir=None (it
    # treats a parameter of that name as a restore directory), so fall back here.
    if checkpoint_dir is None:
        checkpoint_dir = config.CHECKPOINT_DIR
    lr, batch_size = tune_config['lr'], tune_config['batch_size']
    model = TrashResnet50(lr=lr)
    dm = TrashNetDataModule(
        transfer_learning=transfer_learning, batch_size=batch_size)
    checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{val_loss:.2f}-{val_accuracy:.2f}',
        monitor='val_loss'
    )
    # Map Tune's metric names to the keys logged by the LightningModule.
    metrics = {'loss': 'val_loss', 'acc': 'val_accuracy'}
    tune_callback = TuneReportCallback(metrics, on='validation_end')
    trainer = pl.Trainer(max_epochs=num_epochs, gpus=ceil(num_gpus), log_every_n_steps=20,
                         enable_progress_bar=False, callbacks=[checkpoint_callback, tune_callback])
    trainer.fit(model, dm)


In [None]:
# Search space: log-uniform learning rate, categorical batch size.
tune_config = dict(
    lr=tune.loguniform(1e-5, 5e-2),
    batch_size=tune.choice([32, 64, 128, 256]),
)

# End a trial early once its reported 'loss' plateaus.
stopper = tune.stopper.TrialPlateauStopper('loss', 0.001)

# Each trial requests every CPU/GPU given to ray.init, so trials run one at a time.
# NOTE(review): max_failures=-1 retries failed trials without limit; a deterministic
# failure (e.g. running out of GPU memory at the largest batch size) would keep
# this cell from finishing.
analysis = tune.run(
    tune.with_parameters(train_trash_tune),
    config=tune_config,
    num_samples=config.NUM_SAMPLES,
    metric='loss',
    mode='min',
    stop=stopper,
    resources_per_trial={'cpu': config.NUM_CPUS, 'gpu': config.NUM_GPUS},
    local_dir='./results',
    name='tune_trash_resnet50',
    max_failures=-1,
)
print(analysis.best_config)