# Dados

## Caminhos

In [None]:
# Absolute, machine-specific locations; adjust them when running elsewhere.
# datasets_path and models_path end with '/'; models_path is concatenated
# directly with a file name when a trained model is saved.
datasets_path     = '/homeLocal/michelms/praticas-cv-cnn/datasets/'
models_path       = '/homeLocal/michelms/praticas-cv-cnn/models/'
# NOTE(review): no trailing '/'. Joining a run name onto this with plain string
# concatenation yields a sibling directory 'alexnet<run>' instead of a
# subdirectory, so join it with os.path.join.
tensorboard_path  = '/homeLocal/michelms/praticas-cv-cnn/Tensorboard/alexnet'

## Dataloader

In [None]:
from torch.utils.data import DataLoader
import torchvision

import matplotlib.pyplot as plt
  
import copy

import numpy as np

def my_imshow(img, dataset):
    """Display up to the first 10 images of a batch, 5 per row.

    Args:
        img: Batch tensor accepted by ``torchvision.utils.make_grid``; it must
            live on the CPU because ``.numpy()`` is called on the grid.
        dataset: When ``'cifar10'``, pixel values are mapped with
            ``x / 2 + 0.5``, the inverse of ``Normalize(mean=0.5, std=0.5)``.
    """
    if dataset == 'cifar10':
        img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)

    grid = torchvision.utils.make_grid(img[:10], nrow=5)

    # make_grid produces channel-first (C, H, W); matplotlib expects (H, W, C).
    hwc_grid = grid.numpy().transpose(1, 2, 0)

    plt.axis('off')
    plt.imshow(hwc_grid, interpolation='nearest')
    plt.show()

def show_images(train_loader, test_loader, dataset):
    """Preview the first batch of the train loader, then of the test loader.

    Each preview is labelled with a printed title and drawn by ``my_imshow``.
    The batch is "random" only if the corresponding loader shuffles.

    NOTE(review): in get_data_cifar10 the training split is only ``ToTensor()``
    (not normalized), yet ``my_imshow`` still applies the CIFAR-10
    un-normalization to it, so train previews appear washed out.
    """
    previews = (('Train samples', train_loader), ('Test samples', test_loader))
    for title, loader in previews:
        print(title)
        first_batch_images = next(iter(loader))[0]
        my_imshow(first_batch_images, dataset)

def get_data_cifar10(batch_size, dataset, show_image=False, shuffle_train=True):
    """Build the CIFAR-10 train/test DataLoaders from ``datasets_path``.

    The training split is returned with ``ToTensor()`` only (native CIFAR-10
    resolution, values in [0, 1]). Augmentation, resizing and normalization
    are left to the caller, which applies them per batch (``train`` uses a
    kornia pipeline). The test split is resized to 227 and normalized to
    [-1, 1] here.

    Args:
        batch_size: Mini-batch size used by both loaders.
        dataset: Dataset name; forwarded to ``show_images`` for display.
        show_image: If True, plot one batch from each loader.
        shuffle_train: Reshuffle the training set every epoch (default True).

    Returns:
        ``(train_loader, test_loader, len(train_dataset))``.
    """
    # Test images: resize to AlexNet's 227 input and normalize to [-1, 1].
    my_transform_test = torchvision.transforms.Compose([
        torchvision.transforms.Resize(227),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=f'{datasets_path}/train/',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=False,
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=f'{datasets_path}/test/',
        train=False,
        transform=my_transform_test,
        download=False,
    )

    # Training batches must be reshuffled each epoch. With a fixed order every
    # epoch sees identical batches, which hurts SGD/Adam convergence.
    # The test order does not affect accuracy, so it stays deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    if show_image:
        show_images(train_loader, test_loader, dataset)

    return train_loader, test_loader, len(train_dataset)

In [None]:
# Smoke test: build the CIFAR-10 loaders and preview one train and one test batch.
_ = get_data_cifar10(batch_size=256, dataset='cifar10', show_image=True)
del _  # the loaders are not needed here; only the preview plots matter

# Rede

## Arquitetura

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlexNet(nn.Module):
    """AlexNet-style CNN for 3x227x227 inputs that returns log-probabilities.

    ``extractor`` maps a 3x227x227 image to 256x6x6 feature maps, ``flatten``
    turns them into a 9216-vector, and ``classifier`` (three fully connected
    layers with dropout) ends in ``LogSoftmax``. Pair it with ``nn.NLLLoss``.

    Submodule order is significant: it fixes the ``state_dict`` keys
    (``extractor.0.weight`` ...) used by saved models.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        self.extractor = nn.Sequential(
            # 3x227x227 -> 96x55x55 -> pool -> 96x27x27
            nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # -> 256x27x27 -> pool -> 256x13x13
            nn.Conv2d(96, 256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # three 3x3 convolutions keep 13x13
            nn.Conv2d(256, 384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # -> 256x6x6
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
        )
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x, debug=False):
        """Run the network; with ``debug`` print the shape after every stage."""
        if debug:
            print('input', x.shape)
        out = x
        for label, stage in (('extractor', self.extractor),
                             ('flatten', self.flatten),
                             ('classifier', self.classifier)):
            out = stage(out)
            if debug:
                print(label, out.shape)
        return out

## Informações sobre a rede

In [None]:
from torchsummary import summary

# Select the device: the first CUDA GPU when available, otherwise the CPU.
if not torch.cuda.is_available():
    my_device = torch.device("cpu")
    print("No Cuda Available")
else:
    my_device = torch.device("cuda:0")
    print("Running on CUDA.")

# Build a 10-class AlexNet on that device and push one random 227x227 image
# through it with debug=True so every stage prints its output shape.
net = AlexNet(num_classes=10).to(my_device)

probe = torch.rand((1, 3, 227, 227)).to(my_device)
probe_out = net(probe, debug=True)

del probe, probe_out

In [None]:
# Print a per-layer summary of the probe network for a 3x227x227 input.
# NOTE(review): batch_size=2000 is presumably only used by torchsummary for the
# reported shapes / memory estimate — confirm against the torchsummary docs.
summary(net, input_size=(3,227,227), batch_size=2000, device=my_device.type)
del net  # free the probe model; train() builds its own AlexNet

## Treinamento

In [None]:
from torch.utils.tensorboard import SummaryWriter

import kornia.augmentation as K
import torchvision.transforms as transforms

import torch.optim 
import matplotlib.pyplot as plt
  
from datetime import datetime

from tqdm import tqdm

def plot_layers(net, writer, epoch):
    """Log histograms of every Conv2d in ``net.extractor`` to TensorBoard.

    Convolutions are numbered from 1 in module order. Tags used:
    ``Bias/convN``, ``Weight/convN`` and ``Grad/convN`` (weight gradient).

    Tensors that do not exist are skipped instead of crashing ``add_histogram``:
    the bias when a conv was built with ``bias=False``, and the gradient before
    any backward pass has populated ``weight.grad``.

    Args:
        net: Model exposing an ``extractor`` submodule.
        writer: Object with ``add_histogram(tag, values, step)``, e.g. a
            ``SummaryWriter``.
        epoch: Step value recorded with each histogram.
    """
    convs = [m for m in net.extractor.modules() if isinstance(m, nn.Conv2d)]
    for layer_id, layer in enumerate(convs, start=1):
        if layer.bias is not None:
            writer.add_histogram(f'Bias/conv{layer_id}', layer.bias, epoch)
        writer.add_histogram(f'Weight/conv{layer_id}', layer.weight, epoch)
        if layer.weight.grad is not None:
            writer.add_histogram(f'Grad/conv{layer_id}', layer.weight.grad, epoch)


def train(dataset, prefix=None, upper_bound=100.0, save=False, epochs=100,
          lr=1e-1, device='cpu', debug=False, layers2tensorboard=True, batch_size=64):
    """Train an AlexNet on ``dataset`` and return the best model seen on the test split.

    Args:
        dataset: Dataset name; only ``'cifar10'`` is implemented.
        prefix: Optional run name; a ``YYYYmmdd_HHMMSS`` timestamp is appended.
            The TensorBoard log directory is ``tensorboard_path/<prefix>``.
        upper_bound: Stop early once test accuracy (percent) exceeds this.
        save: If True, ``torch.save`` the best model into ``models_path``.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        device: Device for the model, the batches and the augmentation.
        debug: Print the training loss every 10 batches.
        layers2tensorboard: Log conv weight/bias/grad histograms each epoch.
        batch_size: Mini-batch size for both loaders.

    Returns:
        A deep copy of the network from the epoch with the highest test
        accuracy. Returns ``None`` if the dataset is unsupported or no epoch ran.
    """
    import os

    if dataset == 'cifar10':
        train_loader, test_loader, _ = get_data_cifar10(batch_size, dataset, show_image=True)
        num_classes = 10
    else:
        print('Error, dataloader is not implemented.')
        return None

    net = AlexNet(num_classes)
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.NLLLoss()  # AlexNet ends in LogSoftmax

    suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = suffix if prefix is None else prefix + '-' + suffix

    # tensorboard_path has no trailing separator. Plain concatenation produced
    # '.../alexnetAlexNet-...' next to the intended directory, not inside it.
    writer = SummaryWriter(log_dir=os.path.join(tensorboard_path, prefix))

    # Batch-level augmentation on raw [0, 1] CIFAR-10 batches (32x32, see
    # get_data_cifar10). The random crop works at native resolution with
    # padding: a 227x227 crop cannot be sampled from a 32x32 image, and kornia
    # raises for it. Everything is then resized to the 227x227 AlexNet input
    # and normalized like the test transform.
    transform_train = nn.Sequential(
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomRotation(degrees=15.0, p=0.5),
        K.RandomCrop(size=(32, 32), padding=4, p=0.5),
        K.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
        K.Resize(227),
        K.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    )

    # len(train_loader) counts the final partial batch. dataset_size // batch_size
    # did not, so the last step of one epoch collided with step 0 of the next.
    steps_per_epoch = len(train_loader)

    accuracies = []
    max_accuracy = -1.0
    best_model = None  # stays None if epochs <= 0

    try:
        for epoch in tqdm(range(epochs), desc='Training epochs...'):
            net.train()
            for idx, (train_x, train_label) in enumerate(train_loader):
                # Move first so the (large, 227x227) augmentation runs on `device`.
                train_x = transform_train(train_x.to(device))
                train_label = train_label.to(device)

                predict_y = net(train_x)
                error = criterion(predict_y, train_label)

                step = idx + epoch * steps_per_epoch
                writer.add_scalar('Loss/train', error.item(), step)

                optimizer.zero_grad()
                error.backward()
                optimizer.step()

                predict_ys = torch.max(predict_y, dim=1)[1]
                correct = (predict_ys == train_label).sum().item()
                writer.add_scalar('Accuracy/train', correct / train_x.size(0), step)

                if debug and idx % 10 == 0:
                    print(f'idx: {idx:4d}, _error: {error.item():5.2f}')

            if layers2tensorboard:
                plot_layers(net, writer, epoch)

            accuracy = validate(net, test_loader, device=device)
            accuracies.append(accuracy)
            writer.add_scalar('Accuracy/test', accuracy, epoch)

            if accuracy > max_accuracy:
                best_model = copy.deepcopy(net)
                max_accuracy = accuracy
                print("Saving Best Model with Accuracy: ", accuracy)

            print(f'Epoch: {epoch+1:3d} | Accuracy : {accuracy:7.4f}%')

            if accuracy > upper_bound:
                break

        if save and best_model is not None:
            path = f'{models_path}AlexNet-{dataset}-{max_accuracy:.2f}.pkl'
            torch.save(best_model, path)
            print('Model saved in:', path)

        plt.plot(accuracies)
    finally:
        # Always release the event-file handle, even if training raises.
        writer.flush()
        writer.close()

    return best_model

## Validação

In [None]:
def validate(model, data, device='cpu'):
    """Return the top-1 accuracy of ``model`` on ``data`` as a percentage.

    Switches ``model`` to eval mode (and leaves it there) and runs inference
    under ``torch.no_grad()``, so no autograd graph is kept for the batches.

    Args:
        model: Module whose output holds class scores along dim 1.
        data: Iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``correct * 100 / total`` as a float. Returns ``0.0`` if ``data``
        yields no samples (previously a ZeroDivisionError).
    """
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for test_x, test_label in data:
            test_x = test_x.to(device)
            test_label = test_label.to(device)
            predict_ys = torch.max(model(test_x), dim=1)[1]
            total += test_x.size(0)
            correct += (predict_ys == test_label).sum().item()

    if total == 0:
        return 0.0
    return correct * 100. / total

# Execução

## Treina

In [None]:
# Pick the device used for training and later inference (global my_device).
if not torch.cuda.is_available():
    my_device = torch.device("cpu")
    print("No Cuda Available")
else:
    my_device = torch.device("cuda:0")
    print("CUDA Available.")

print(f'Running on {my_device.type}')

# Experiment configuration. train() only implements 'cifar10'.
dataset = 'cifar10'
epochs = 50
lr = 1e-3
prefix = f'AlexNet-{dataset}-e-{epochs}-lr-{lr}'

# Release cached, unused GPU memory before the training run allocates its own.
torch.cuda.empty_cache()

net = train(
    dataset=dataset,
    epochs=epochs,
    device=my_device,
    save=True,
    prefix=prefix,
    lr=lr,
    layers2tensorboard=True,
    batch_size=3076,
)

# Carregar dado e inferir

In [None]:
def sample_and_predict(dataset='cifar10', seed=None, model=None, device=None):
    """Pick a random test image, display it, and print the model's prediction.

    Args:
        dataset: Only ``'cifar10'`` is supported.
        seed: Optional seed for NumPy's global RNG, making the sample reproducible.
        model: Network to query; defaults to the global ``net``. Assumed to
            output log-probabilities (as AlexNet's LogSoftmax head does).
        device: Device for the input; defaults to the global ``my_device``.

    Returns:
        ``(predicted_class, true_class, confidence)`` with CIFAR-10 class names
        and confidence in [0, 1], or ``False`` for an unsupported dataset.
    """
    if seed is not None:
        np.random.seed(seed)

    if dataset != 'cifar10':
        print('ERROR: Dataset not defined.')
        return False

    model = net if model is None else model
    device = my_device if device is None else device

    # Same preprocessing as the CIFAR-10 test loader.
    my_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(227),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # No transform here: the raw PIL image is shown before preprocessing.
    data = torchvision.datasets.CIFAR10(
        root=f'{datasets_path}/test/',
        train=False,
        download=False,
    )

    i = np.random.randint(len(data))
    sample, label = data[i]

    plt.figure(figsize=(2, 2))
    plt.axis('off')
    plt.imshow(sample)

    print(f'Sample id: {i:3d}')

    x = my_transform(sample)
    print(x.shape)

    x = x.unsqueeze(0)  # add batch dimension
    print(x.shape)

    x = x.to(device)

    # Dropout must be off for a meaningful single prediction; restore the
    # caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(x)
    finally:
        model.train(was_training)

    # softmax(log_softmax(z)) == softmax(z): recovers class probabilities.
    softmax = nn.functional.softmax(output, dim=-1)
    confidence, y = torch.max(softmax, 1)
    y = y.item()
    confidence = confidence.item()

    dataset_classes = ('plane', 'car', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')

    print('Hit' if y == label else 'Miss')

    print(f'Predicted: {dataset_classes[y]} | Corrected: {dataset_classes[label]} | Confidence: {confidence*100:.2f}%')

    return dataset_classes[y], dataset_classes[label], confidence


In [None]:
# Predict one random CIFAR-10 test image with the trained global `net`.
_ = sample_and_predict()