In [4]:
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch
from torch import nn
import torch
from torch.utils.data import DataLoader, Subset
import os
import torchvision.models as models


In [2]:
class WeightAligning_NeuralNetwork(nn.Module):
    """Three-layer MLP classifier whose output layer applies Weight Aligning (WA).

    The first ``num_old_classes`` rows of ``fc3.weight`` belong to old classes and
    the remaining rows to new classes. When the head has both old and new rows,
    ``forward`` rescales the new-class rows by

        gamma = mean_i ||W_old[i]|| / mean_j ||W_new[j]||

    before computing logits. The rescaling is applied on the fly; the stored
    ``fc3`` parameters are never modified by ``forward``.

    Args:
        input_size: Number of features after flattening the input (``fc1`` in_features).
        hidden_size: Width of the two hidden layers.
        output_size: Total number of output classes (old + new).
        num_old_classes: Number of leading ``fc3`` rows that belong to old classes.
    """

    def __init__(self, input_size, hidden_size, output_size, num_old_classes):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_old_classes = num_old_classes

        # ResNet-18 with every child module except the final classification layer.
        # NOTE(review): `features` is never used in `forward`; it only adds
        # parameters / state_dict entries. Kept so existing checkpoints still load.
        # `pretrained=` is deprecated in newer torchvision (`weights=None` there).
        resnet = models.resnet18(pretrained=False)
        self.features = nn.Sequential(*list(resnet.children())[:-1])

        self.flatten = nn.Flatten()
        # Fully connected layers.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def _aligned_fc3_weight(self, eps=1e-12):
        """Return a new tensor equal to ``fc3.weight`` with new-class rows scaled by gamma.

        Args:
            eps: Lower bound for the mean new-class norm, so gamma stays finite
                when every new-class row is zero.

        Returns:
            Tensor with the same shape as ``fc3.weight``. It is built from the
            parameter through autograd ops, so gradients flow back to
            ``fc3.weight`` (including through gamma).
        """
        n_old = self.num_old_classes
        weight = self.fc3.weight
        weights_old = weight[:n_old]
        weights_new = weight[n_old:]

        mean_norm_old = weights_old.norm(dim=1).mean()
        mean_norm_new = weights_new.norm(dim=1).mean().clamp_min(eps)
        gamma = mean_norm_old / mean_norm_new

        return torch.cat([weights_old, gamma * weights_new], dim=0)

    def forward(self, x):
        """Compute class logits for ``x``.

        Args:
            x: Input tensor. It is flattened from dim 1 onward, so the flattened
                feature count must equal ``input_size``.

        Returns:
            Logits of shape ``(batch, output_size)``.
        """
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Alignment needs at least one old and one new class: with no old rows the
        # old-norm mean is NaN (which would poison every new-class logit), and with
        # no new rows there is nothing to rescale.
        if 0 < self.num_old_classes < self.output_size:
            # Use the aligned weight functionally rather than re-registering
            # `self.fc3.weight` as a new nn.Parameter. Replacing the parameter on
            # every call detached fc3 from any optimizer holding the previous
            # object and rewrote the stored weights as a side effect of inference.
            return F.linear(x, self._aligned_fc3_weight(), self.fc3.bias)

        return self.fc3(x)
