<a href="https://colab.research.google.com/github/jcmachicao/MachineLearningAvanzado_UC_2024/blob/main/U2_MLADV__Torch_NAS__JCMV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchdyn

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchdyn.models import Sequential

In [None]:
# Define the search space for the NAS (DARTS)
class MixedOp(nn.Module):
    """Weighted mixture of all candidate operations on a single edge.

    One module per primitive is built from the module-level ``OPS`` table.
    The table is looked up at construction time, so it only has to exist
    once a ``MixedOp`` is instantiated.

    Args:
        C: Channel count forwarded to every operation factory.
        stride: Stride forwarded to every operation factory.
    """

    def __init__(self, C, stride):
        super().__init__()
        primitive_names = (
            "identity",
            "max_pool_3x3",
            "avg_pool_3x3",
            "skip_connect",
            "sep_conv_3x3",
        )
        # Every factory is called with affine=False.
        self.ops = nn.ModuleList(
            [OPS[name](C, stride, False) for name in primitive_names]
        )

    def forward(self, x, weights):
        """Return the sum of ``weight * op(x)`` over paired weights and ops.

        ``weights`` is paired with the operations positionally; pairing stops
        at the shorter of the two sequences.
        """
        mixed = 0
        for weight, candidate in zip(weights, self.ops):
            mixed = mixed + weight * candidate(x)
        return mixed

class NASCell(nn.Module):
    """DARTS-style search cell: a small DAG with ``steps`` intermediate nodes.

    Intermediate node ``i`` is the sum of one ``MixedOp`` edge from every
    earlier state (the two preprocessed cell inputs plus nodes ``0..i-1``),
    so the cell owns ``sum(2 + i for i in range(steps))`` edges in total.

    Args:
        steps: Number of intermediate nodes.
        C: Channel count expected on both inputs (each goes through a
            ``Conv2d(C, C)`` preprocess) and passed to every ``MixedOp``.
    """

    def __init__(self, steps, C):
        super().__init__()
        self.preprocess0 = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C),
        )
        self.preprocess1 = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C),
        )
        self._steps = steps
        self._ops = nn.ModuleList()
        for i in range(self._steps):
            # Node i has 2 + i predecessors and needs one edge from each.
            # forward() walks the edges in exactly this order.
            for _ in range(2 + i):
                self._ops.append(MixedOp(C, 1))

    def forward(self, s0, s1, weights):
        """Run the cell on the outputs of the two preceding layers.

        Args:
            s0: First input feature map.
            s1: Second input feature map.
            weights: Indexable by edge; ``weights[k]`` holds the mixing
                coefficients handed to the ``MixedOp`` on edge ``k``.

        Returns:
            The last ``steps`` states concatenated along ``dim=1``.

        Raises:
            ValueError: If ``weights`` has fewer rows than the cell has edges.
        """
        num_edges = len(self._ops)
        if len(weights) < num_edges:
            raise ValueError(
                f"weights has {len(weights)} rows but the cell has {num_edges} edges"
            )

        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0  # index of the first edge feeding the current node
        for _ in range(self._steps):
            s = sum(
                self._ops[offset + j](h, weights[offset + j])
                for j, h in enumerate(states)
            )
            offset += len(states)
            states.append(s)
        return torch.cat(states[-self._steps:], dim=1)


In [None]:
# Define the search space operations
#
# Every factory in OPS takes (C, stride, affine) and returns an nn.Module.
# The building blocks are implemented here with plain torch.nn layers.


def _pool_bn(pool, C, affine):
    """Return ``pool`` followed by batch norm over ``C`` channels."""
    return nn.Sequential(pool, nn.BatchNorm2d(C, affine=affine))


def _skip_connect(C, stride, affine):
    """Return a pass-through for stride 1, else a strided 1x1 conv block."""
    if stride == 1:
        return nn.Identity()
    # A pure identity cannot change the spatial size, so strided skips use a
    # 1x1 convolution to downsample.
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(C, C, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(C, affine=affine),
    )


def _sep_conv(C_in, C_out, kernel_size, stride, padding, affine):
    """Return ReLU -> depthwise conv -> pointwise 1x1 conv -> batch norm."""
    return nn.Sequential(
        nn.ReLU(inplace=False),
        # groups=C_in makes this a depthwise (per-channel) convolution.
        nn.Conv2d(
            C_in,
            C_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=C_in,
            bias=False,
        ),
        nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
        nn.BatchNorm2d(C_out, affine=affine),
    )


OPS = {
    "identity": lambda C, stride, affine: nn.Identity(),
    "max_pool_3x3": lambda C, stride, affine: _pool_bn(
        nn.MaxPool2d(3, stride=stride, padding=1), C, affine
    ),
    "avg_pool_3x3": lambda C, stride, affine: _pool_bn(
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
        C,
        affine,
    ),
    "skip_connect": lambda C, stride, affine: _skip_connect(C, stride, affine),
    "sep_conv_3x3": lambda C, stride, affine: _sep_conv(C, C, 3, stride, 1, affine),
}

# Define the final model using NAS
class NASNet(nn.Module):
    """Classifier made of two stem convolutions and a stack of ``NASCell``s.

    Args:
        C: Channel width used by the stems and by every cell.
        steps: Number of intermediate nodes per cell; the same value is also
            used as the number of stacked cells.
        num_classes: Size of the classifier output.
    """

    def __init__(self, C, steps, num_classes):
        super().__init__()
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.cells = nn.ModuleList([NASCell(steps, C) for _ in range(steps)])

        # A cell concatenates its `steps` nodes, so it emits steps * C
        # channels, while the next cell expects C-channel inputs. A 1x1
        # projection after every cell but the last bridges that gap.
        cell_channels = C * steps
        self.reducers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(cell_channels, C, kernel_size=1, bias=False),
                    nn.BatchNorm2d(C),
                )
                for _ in range(max(steps - 1, 0))
            ]
        )

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The classifier reads the un-projected output of the last cell.
        self.classifier = nn.Linear(cell_channels, num_classes)

    def forward(self, x, weights):
        """Classify ``x`` using the architecture logits in ``weights``.

        Args:
            x: Image batch with 3 channels (``stem0`` is ``Conv2d(3, C)``).
            weights: Architecture logits with one row per cell edge and one
                column per candidate operation. The same logits are shared
                by every cell.

        Returns:
            Class logits, one row per input sample.
        """
        s0 = self.stem0(x)
        s1 = self.stem1(s0)

        # Normalise once, per edge, over the candidate operations. Doing it
        # inside the loop would re-apply softmax for every cell.
        edge_weights = weights.softmax(dim=-1)

        last = len(self.cells) - 1
        for idx, cell in enumerate(self.cells):
            cell_out = cell(s0, s1, edge_weights)
            if idx < last:
                s0, s1 = s1, self.reducers[idx](cell_out)
            else:
                s1 = cell_out

        out = self.global_pooling(s1)
        out = out.view(out.size(0), -1)
        return self.classifier(out)


In [None]:
# Example: build a NASNet and run a single forward pass with random
# architecture weights (no training or architecture search happens here)
C = 16  # Number of channels
steps = 4  # Number of nodes per cell (also used as the number of cells)
num_classes = 10  # Number of output classes

# Architecture weights need one row per cell edge and one column per
# candidate operation. Node i of a cell has 2 + i incoming edges, so
# steps = 4 gives 2 + 3 + 4 + 5 = 14 edges.
num_edges = sum(2 + i for i in range(steps))
num_ops = 5  # number of primitives mixed on every edge (see MixedOp)

# Random weights (for simplicity; in practice an optimization algorithm
# would learn them)
weights = torch.randn(num_edges, num_ops, requires_grad=True)

# Create the NASNet model
nasnet = NASNet(C, steps, num_classes)

# Example input (batch size, channels, height, width)
input_data = torch.randn(1, 3, 32, 32)

# Forward pass with random weights
output = nasnet(input_data, weights)

# Print the output shape
print("Output shape:", output.shape)