In [1]:
import torchvision
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset

import os
import torch.nn.functional as F
import torchvision.models as models

total_fase = 2 
save_model = True
load_model = False
device = "cuda"
data_dir = "../data" 
batch_size = 32

print("GPU activa:", torch.cuda.is_available(), "\nCantidad de GPs", torch.cuda.device_count())
#------------------------------------------------------------------------------------------------

# Conjunto de datos MNIST
train_data = torchvision.datasets.MNIST(
    root=data_dir,
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
eval_data = torchvision.datasets.MNIST(
    root=data_dir,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
num_classes = 10  
class_dataloaders = []

# Itera sobre cada clase y crea un DataLoader para esa clase
for class_idx in range(num_classes):
    # Obtén los índices para la clase actual
    class_indices = [i for i in range(len(train_data)) if train_data.targets[i] == class_idx]
    
    # Subconjunto de datos para la clase actual
    class_subset = Subset(train_data, class_indices)
    
    # DataLoader para la clase actual
    class_dataloader = DataLoader(class_subset, batch_size=batch_size, shuffle=True)
    
    # Agrega el DataLoader al listado
    class_dataloaders.append(class_dataloader)


eval_dataloader = DataLoader(eval_data, batch_size=10000, shuffle=True)
print("Se cargaron los datos correctamente")


GPU activa: True 
Cantidad de GPs 1
Se cargaron los datos correctamente


In [31]:
# Definir una arquitectura de ResNet modificada para MNIST
class ResNetMNIST(nn.Module):
    def __init__(self):
        super(ResNetMNIST, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=False)
        # Cambiar la primera capa convolucional para aceptar imágenes en escala de grises
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Cambiar la última capa lineal para tener 10 clases de salida
        self.resnet.fc = nn.Identity()

    def forward(self, x):
        x = self.resnet(x)
        return x

# Inicializar el modelo ResNetMNIST
model = ResNetMNIST(); model.to("cuda")



ResNetMNIST(
  (resnet): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_

In [14]:
class_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x7f9f92608150>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92608c50>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92608ed0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92608fd0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f926090d0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92609210>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92609310>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92609410>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92609610>,
 <torch.utils.data.dataloader.DataLoader at 0x7f9f92609690>]

In [15]:
def extractor(dataloader, model):
    model.eval()
    all_features = []
    all_labels = []

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # Obtener las características (vectores de características) del modelo
            features = model(X)
            all_features.append(features.cpu())
            all_labels.append(y.cpu())

    # Concatenar todas las características y etiquetas
    all_features = torch.cat(all_features, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    
    return all_features, all_labels

In [32]:
model.eval()
embedding = list()
for data_i in class_dataloaders:
    features, labels = extractor(data_i, model)
    embedding.append((features, labels))

In [41]:
def prototype(embedding):
    return [ torch.sum(pt[0], dim =0)/(pt[0].shape[0]) for pt in embedding ]

### Esto deberían ser centroides entonces en dado caso de graficarlos deberian salir en el centro

In [42]:
p_i = prototype(embedding)