# In [None]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from skimage import io
from torchvision import transforms
from torch import nn, optim
from torch.nn import functional as F
from PIL import Image
import numpy as np
from imgaug import augmenters as iaa
import imgaug as ia
from torchvision import models


# Show the current working directory: relative paths used later in this
# notebook (e.g. "Dataset_git.csv") are resolved against it.
print(os.getcwd())

# NOTE: the triple-quoted string below holds a disabled imgaug-based
# augmentation pipeline (ImgAugTransform). As a bare string literal it is
# never executed, and SsjDataset does not use it; restore it as real code
# (and wire it into the dataset transform) only if augmentation is wanted.
'''
class ImgAugTransform:
    def __init__(self):
        self.aug = iaa.Sequential([
            iaa.Resize((224, 224)),
            iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0))),
            iaa.Fliplr(0.5),
            iaa.Affine(rotate=(-20, 20), mode='symmetric'),
            iaa.Sometimes(0.25,
                          iaa.OneOf([iaa.Dropout(p=(0, 0.1)),
                                     iaa.CoarseDropout(0.1, size_percent=0.5)])),
            iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True)
        ])

    def __call__(self, img):
        img = np.array(img)
        return self.aug.augment_image(img)
        '''

class SsjDataset(Dataset):
    """Image-classification dataset backed by a ';'-separated CSV file.

    The first CSV column holds image file paths and the second column the
    labels. Each image is converted to 3-channel RGB, resized to 224x224,
    turned into a tensor and normalized with mean/std 0.5 per channel.

    Args:
        archivocsv: Path to the CSV file (read with pandas, ';' delimiter).
    """

    def __init__(self, archivocsv):
        self.arch_csv = pd.read_csv(archivocsv, engine='python', delimiter=';')
        # Column 0: image paths; column 1: labels. The labels are presumably
        # integer class ids in [0, 3), matching the 3-output heads trained
        # with CrossEntropyLoss in this notebook. TODO confirm.
        self.ruta_imagenes = np.asarray(self.arch_csv.iloc[:, 0])
        self.arr_labels = np.asarray(self.arch_csv.iloc[:, 1])
        self.largo_datos = len(self.arch_csv.index)
        # NOTE(review): mean/std of 0.5 differ from the ImageNet statistics
        # that torchvision's pretrained models document; left unchanged so
        # previously saved checkpoints keep seeing the same inputs.
        self.transformar = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.arch_csv)

    def __getitem__(self, indice):
        """Return ``(image_tensor, label)`` for CSV row *indice*."""
        nom_imagen = self.ruta_imagenes[indice]
        # The context manager closes the underlying file handle. convert('RGB')
        # returns a fully loaded copy, so it stays usable after the file
        # closes. It also makes grayscale/RGBA/palette images 3-channel,
        # which Normalize and the 3-channel conv input require.
        with Image.open(nom_imagen) as imagen:
            imagen = imagen.convert('RGB')
        label_salida = self.arr_labels[indice]
        return self.transformar(imagen), label_salida


# Load the dataset and split it 60/40 into training and test subsets.
# A fixed-seed generator makes the split reproducible, so the held-out test
# subset is the same on every run. Without it, images used for training in
# one run could land in another run's test set.
dataset = SsjDataset("Dataset_git.csv")
print(len(dataset))
tam_ent = int(0.6 * len(dataset))
tam_pru = len(dataset) - tam_ent
dispositivo = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
semilla_split = 42  # change to draw a different (but still reproducible) split
set_ent, set_pru = torch.utils.data.random_split(
    dataset,
    [tam_ent, tam_pru],
    generator=torch.Generator().manual_seed(semilla_split),
)
dl_ent = DataLoader(set_ent, batch_size=16, shuffle=True)
dl_pru = DataLoader(set_pru, batch_size=16, shuffle=False)


class Red(nn.Module):
    """Small LeNet-style CNN producing 3 output logits.

    Architecture: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three
    linear layers (300 -> 84 -> 84 -> 3). Before flattening, the feature map
    is adaptively average-pooled to 5x5. For 32x32 inputs that pool is an
    exact identity, so the original behaviour is preserved. Larger inputs,
    such as the 224x224 images produced by ``SsjDataset``, also work.
    """

    def __init__(self):
        super(Red, self).__init__()
        self.ent_conv1 = nn.Conv2d(3, 8, 5)
        self.conv1_conv2 = nn.Conv2d(8, 12, 5)
        self.conv2_lineal1 = nn.Linear(12 * 5 * 5, 84)
        self.lineal1_lineal2 = nn.Linear(84, 84)
        self.lineal2_salida = nn.Linear(84, 3)
        self.pooling2d = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return logits of shape ``(batch, 3)``.

        *x* must have 3 channels (``ent_conv1`` is ``Conv2d(3, ...)``) and be
        large enough spatially to survive two 5x5 convolutions and two 2x2
        pools; it is cast to float first.
        """
        x = self.pooling2d(F.relu(self.ent_conv1(x.float())))
        x = self.pooling2d(F.relu(self.conv1_conv2(x)))
        # Reduce any spatial size to the 5x5 grid conv2_lineal1 was sized for
        # (a no-op when the map is already 5x5, i.e. for 32x32 inputs).
        x = F.adaptive_avg_pool2d(x, (5, 5))
        x = torch.flatten(x, 1)
        x = F.relu(self.conv2_lineal1(x))
        x = F.relu(self.lineal1_lineal2(x))
        return self.lineal2_salida(x)


# Transfer-learning setup:
# - start from an ImageNet-pretrained ResNet-18;
# - freeze every pretrained parameter;
# - replace the final fully connected layer with a fresh 3-class head,
#   which becomes the only trainable layer.
# NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
# favour of ``weights=models.ResNet18_Weights.IMAGENET1K_V1``.
modelo = models.resnet18(pretrained=True)
modelo.requires_grad_(False)  # equivalent to requires_grad=False on every parameter
ent_modelo = modelo.fc.in_features
modelo.fc = nn.Linear(ent_modelo, 3)  # newly created layer: requires_grad=True
print(modelo)
modelo.to(dispositivo)  # nn.Module.to moves parameters in place and returns self

func_error = nn.CrossEntropyLoss()
# Frozen parameters never get a .grad, so SGD effectively updates only the head.
optimizador = optim.SGD(modelo.parameters(), lr=0.001, momentum=0.9)
epocas = 20
error_acumulado = 0.0
# NOTE: in train mode the frozen BatchNorm layers still update their running
# mean/variance during forward passes, even though their weights are frozen.
modelo.train()

# Training loop: one forward / loss / backward / SGD step per mini-batch.
# Every INTERVALO_LOG mini-batches the mean loss over exactly those batches
# is printed; a partial group left at the end of an epoch is reported too.
INTERVALO_LOG = 7


def _reportar_error(epoca, lote, suma_error, n_lotes):
    """Print the mean loss of the last *n_lotes* mini-batches (1-based labels)."""
    print('[%d, %5d] error: %.3f' % (epoca + 1, lote + 1, suma_error / n_lotes))


for epoca in range(epocas):
    modelo.train()  # re-enter training mode in case eval() ran in between
    error_acumulado = 0.0
    lotes_acumulados = 0
    for indice, (entradas, objetivos) in enumerate(dl_ent):
        entradas, objetivos = entradas.to(dispositivo), objetivos.to(dispositivo)
        optimizador.zero_grad()
        salidas = modelo(entradas)
        error = func_error(salidas, objetivos)
        error.backward()
        optimizador.step()
        error_acumulado += error.item()
        lotes_acumulados += 1
        if lotes_acumulados == INTERVALO_LOG:
            _reportar_error(epoca, indice, error_acumulado, lotes_acumulados)
            error_acumulado = 0.0
            lotes_acumulados = 0
    # Flush the batches left over at the end of the epoch, if any.
    if lotes_acumulados:
        _reportar_error(epoca, indice, error_acumulado, lotes_acumulados)

torch.save(modelo.state_dict(), "SSJ_entrenado.pth")
print("Modelo entrenado y guardado")

# Evaluate on the held-out split: accuracy = correct predictions / samples.
correctas = 0
totales = 0

modelo.eval()  # BatchNorm layers switch to their running statistics
with torch.no_grad():
    for entradas, objetivos in dl_pru:
        entradas, objetivos = entradas.to(dispositivo), objetivos.to(dispositivo)
        salidas = modelo(entradas)
        prediccion = salidas.argmax(dim=1)  # index of the highest logit per sample
        totales += objetivos.size(0)
        correctas += (prediccion == objetivos).sum().item()

# Guard against an empty test split, which would otherwise divide by zero.
if totales:
    print("Precisión del modelo: ", 100 * correctas / totales, "%")
else:
    print("No hay muestras de prueba para evaluar el modelo")
