In [None]:
import torchvision
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

import os
import torch.nn.functional as F
import torchvision.models as models

# --- Run configuration (values read by the cells below) ---
total_fase = 2
save_model = True
load_model = False
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
data_dir = "../data"
batch_size = 32

print("GPU activa:", torch.cuda.is_available(), "\nCantidad de GPs", torch.cuda.device_count())
#------------------------------------------------------------------------------------------------

# MNIST train / evaluation splits (downloaded into data_dir when missing)
train_data = torchvision.datasets.MNIST(
    root=data_dir,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
eval_data = torchvision.datasets.MNIST(
    root=data_dir,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
num_classes = 10
class_dataloaders = []

# Build one DataLoader per class, each restricted to the samples of that class
for class_idx in range(num_classes):
    # Positions of the samples labelled class_idx. A single vectorised mask
    # over the label tensor replaces a Python loop that compared every label
    # one element at a time.
    class_indices = torch.nonzero(
        torch.as_tensor(train_data.targets) == class_idx, as_tuple=True
    )[0].tolist()

    # Subset of the training data for the current class
    class_subset = Subset(train_data, class_indices)

    # DataLoader for the current class
    class_dataloader = DataLoader(class_subset, batch_size=batch_size, shuffle=True)

    # Keep the loaders ordered by class index
    class_dataloaders.append(class_dataloader)


# Evaluation needs no shuffling; a fixed order keeps the logged predictions
# reproducible from run to run.
eval_dataloader = DataLoader(eval_data, batch_size=10000, shuffle=False)
print("Se cargaron los datos correctamente")


In [None]:
# ResNet-18 adapted to MNIST and used as a feature extractor
class ResNetMNIST(nn.Module):
    """ResNet-18 variant that maps single-channel images to feature vectors.

    Two layers of the stock torchvision network are replaced:

    * ``conv1`` takes 1 input channel (grayscale) instead of 3.
    * ``fc`` becomes ``nn.Identity``, so ``forward`` returns the features that
      would normally feed the classification layer rather than class logits.
    """

    def __init__(self):
        super().__init__()
        # Called without arguments to get randomly initialised weights; this
        # avoids the deprecated ``pretrained`` keyword while still working on
        # torchvision releases that predate the ``weights`` keyword.
        self.resnet = torchvision.models.resnet18()
        # First convolution re-declared to accept grayscale images
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the classification head: the network outputs embeddings, not logits
        self.resnet.fc = nn.Identity()

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        return self.resnet(x)

# Instantiate the feature extractor on the configured device (not a hard-coded one)
model = ResNetMNIST().to(device)

In [None]:
# Display the list of per-class DataLoaders as a quick sanity check.
class_dataloaders

In [None]:
def extractor(dataloader, model):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` pairs of tensors.
        model: ``nn.Module`` applied to each input batch. It is switched to
            evaluation mode as a side effect.

    Returns:
        tuple: ``(all_features, all_labels)``, both CPU tensors, with the
        per-batch results concatenated along dim 0 in iteration order.

    Note:
        ``torch.cat`` raises if the dataloader yields no batches.
    """
    model.eval()
    # Send inputs to wherever the model's weights live instead of relying on a
    # global device setting, so the two can never disagree. A model without
    # parameters leaves the inputs where they are.
    first_param = next(model.parameters(), None)
    all_features = []
    all_labels = []

    with torch.no_grad():
        for X, y in dataloader:
            if first_param is not None:
                X = X.to(first_param.device)
            # Feature vectors produced by the model for this batch
            all_features.append(model(X).cpu())
            # Labels are only stored, never fed to the model, so they do not
            # need to visit the accelerator at all.
            all_labels.append(y.cpu())

    # Concatenate the per-batch features and labels
    return torch.cat(all_features, dim=0), torch.cat(all_labels, dim=0)

### Generate embeddings from the training data

In [None]:
# Evaluation mode before extracting features
model.eval()
# One (features, labels) pair per class, in the order of class_dataloaders
embedding = [extractor(data_i, model) for data_i in class_dataloaders]

In [37]:
def prototype(embedding):
    """Compute one mean feature vector per entry of ``embedding``.

    Args:
        embedding: Sequence whose items hold a feature tensor at index 0;
            the mean is taken over that tensor's first dimension.

    Returns:
        list: One tensor per entry, moved to the module-level ``device``.
    """
    centroids = []
    for entry in embedding:
        feats = entry[0]
        # Sum over the sample dimension divided by the sample count
        centroid = feats.sum(dim=0) / feats.shape[0]
        centroids.append(centroid.to(device))
    return centroids

### Estos deberían ser centroides; en caso de graficarlos, deberían quedar en el centro de su clase

In [38]:
# One prototype (mean feature vector) per entry of `embedding`.
p_i = prototype(embedding)

In [None]:
# Show every prototype next to its position in p_i
list(enumerate(p_i))

In [73]:
# Length of the first prototype vector along its first dimension.
p_i[0].size(0)

512

### Clasificación

In [76]:
def classifier(p_i, images):
    """Assign each feature vector to its nearest prototype (Euclidean distance).

    Args:
        p_i: Sequence of prototype vectors, one per class, all of the same
            length. The number of classes is taken from ``len(p_i)``.
        images: Feature vectors to classify, shaped ``(num_images, dim)``.
            A single 1-D vector of length ``dim`` is treated as one sample.

    Returns:
        torch.Tensor: CPU ``long`` tensor of shape ``(num_images,)`` holding,
        for each sample, the index in ``p_i`` of the closest prototype.
    """
    prototypes = torch.stack(list(p_i))  # (num_classes, dim)
    # Work on the prototypes' device so mixed CPU/GPU inputs do not fail
    images = images.to(prototypes.device)
    if images.dim() == 1:
        # A lone vector is one sample, not `dim` scalar samples
        images = images.unsqueeze(0)

    # All pairwise distances at once. The direct (non-matmul) mode computes
    # the same per-pair L2 norm as an explicit difference-and-norm would.
    distances = torch.cdist(
        images, prototypes, compute_mode="donot_use_mm_for_euclid_dist"
    )  # (num_images, num_classes)
    return distances.argmin(dim=1).cpu()

In [None]:
# Smoke test: classify the first training embedding of class 1. The slice
# keeps the batch dimension (shape (1, dim)) and the sample is moved to the
# prototypes' device, as the classifier expects a batch of feature vectors.
classifier(p_i, embedding[1][0][:1].to(device))

In [91]:
def test_loop(dataloader, model, classifier, p_i):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    tarjet_prediction = []  # Lista para almacenar las etiquetas reales y predicciones


    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            embedding = model(X).to(device); # print(embedding.shape)
            pred = classifier(p_i, embedding)
            print(pred.shape, y.shape)
            #test_loss += loss_fn(pred, y).item()
            #correct += (pred == y).type(torch.float).sum().item()
            tarjet_prediction.extend(list(zip(y.cpu().numpy(), pred.cpu().numpy())))

    correct /= size
    
    directorio= 'logs'
    if not os.path.exists(directorio):
        os.makedirs(directorio)
    with open(f'logs/epoch_{0}_CC_{0}.txt', 'w') as archivo:
            # Escribe el valor de la variable en el archivo
            archivo.write(str(tarjet_prediction))
    print(f'El valor prediciones se ha guardado en el archivo.txt')

In [92]:
# Not read anywhere in the visible part of the notebook.
epochs = 1

# Starts empty; nothing in the visible cells appends to it.
log_accuracy_loss = []

# Classify the evaluation set against the prototypes and write the prediction log.
test_loop(eval_dataloader, model, classifier, p_i)


torch.Size([10000]) torch.Size([10000])
El valor prediciones se ha guardado en el archivo.txt
