In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [None]:
class VisionTransformer(nn.Module):
    """Minimal transformer classifier for small RGB images (CIFAR-10 sized by default).

    The whole image is flattened and linearly projected to a single ``d_model``
    token. That token goes through a stack of ``nn.TransformerEncoderLayer``
    blocks and then a linear classification head.

    Args:
        num_classes: Number of output logits per image.
        rank_config: Low-rank configuration. It is stored on the instance but
            not read anywhere in this class; no low-rank decomposition is
            applied yet (see the TODO in ``__init__``).
        in_channels: Channels per input image.
        image_size: Height and width of the (square) input image.
        d_model: Width of the token embedding and of the encoder layers.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, num_classes, rank_config, *, in_channels=3, image_size=32,
                 d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.num_classes = num_classes
        self.rank_config = rank_config

        # Project the flattened image (C*H*W values) to one d_model-wide token.
        self.embedding = nn.Linear(in_channels * image_size * image_size, d_model)
        # batch_first=True makes the encoder expect (batch, seq, d_model).
        # With the default batch_first=False, a 2-D (batch, d_model) tensor is
        # read as ONE unbatched sequence, so images in a batch attend to each other.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, num_classes)

        # TODO: build low-rank layers from rank_config, e.g. an SVD-based
        # factorisation of self.fc.weight whose rank comes from rank_config.

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Image batch. Every dimension after the first is flattened, so
                each sample must hold in_channels * image_size * image_size values.
        """
        # flatten(1) also accepts non-contiguous inputs, unlike view().
        x = x.flatten(1)

        # (batch, d_model) -> (batch, 1, d_model): exactly one token per image.
        x = self.embedding(x).unsqueeze(1)

        for layer in self.transformer_layers:
            x = layer(x)

        # Take the single token per image and classify it.
        return self.fc(x[:, 0])

# Smoke test: build a 10-class model with a single fixed rank setting.
rank_config = dict(rank=16)  # adjust the rank here as needed
model = VisionTransformer(rank_config=rank_config, num_classes=10)

In [None]:
# Per-channel (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixel range to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train/test splits stored under ./data, plus a loader for each.
# Only the training loader shuffles.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)

# Readable class names, in label-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Low-rank search settings. NOTE(review): the VisionTransformer class in this
# notebook only stores rank_config and never reads it, so the run below trains
# an ordinary full-rank model.
rank_config = {
    'num_ranks': 4,                        # Number of ranks (e.g., 4 for 4 ranks)
    'rank_candidates': [16, 32, 64, 128],  # List of candidate rank values
    'filter_strategy': 'top_k',            # Candidate filtering strategy (e.g., 'top_k' or 'threshold')
    'filter_value': 2                      # Value for filtering (e.g., top 2 candidates)
}

# Seed torch's global RNG so weight init and the shuffle order repeat across runs.
torch.manual_seed(0)

# One logit per CIFAR-10 class. The earlier num_classes=1000 wasted 990 outputs
# and made the initial loss ln(1000) instead of ln(10).
model = VisionTransformer(num_classes=len(classes), rank_config=rank_config)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
log_interval = 100  # print the loss every `log_interval` batches

# A freshly built module is already in training mode; set it explicitly so the
# mode the loop runs in is visible.
model.train()
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(trainloader)}] Loss: {loss.item()}")

  self.pid = os.fork()


Epoch [1/10] Batch [1/782] Loss: 6.880486011505127
Epoch [1/10] Batch [101/782] Loss: 2.2895030975341797
Epoch [1/10] Batch [201/782] Loss: 2.364046812057495
Epoch [1/10] Batch [301/782] Loss: 2.2933807373046875
Epoch [1/10] Batch [401/782] Loss: 2.368079900741577
Epoch [1/10] Batch [501/782] Loss: 2.407742500305176
Epoch [1/10] Batch [601/782] Loss: 2.307753562927246
Epoch [1/10] Batch [701/782] Loss: 2.3778953552246094
Epoch [2/10] Batch [1/782] Loss: 2.355701208114624
Epoch [2/10] Batch [101/782] Loss: 2.3230113983154297
Epoch [2/10] Batch [201/782] Loss: 2.338007688522339
Epoch [2/10] Batch [301/782] Loss: 2.318798542022705
Epoch [2/10] Batch [401/782] Loss: 2.3363893032073975
Epoch [2/10] Batch [501/782] Loss: 2.339705228805542
Epoch [2/10] Batch [601/782] Loss: 2.3411717414855957
Epoch [2/10] Batch [701/782] Loss: 2.3137009143829346
Epoch [3/10] Batch [1/782] Loss: 2.2913360595703125
Epoch [3/10] Batch [101/782] Loss: 2.311530590057373
Epoch [3/10] Batch [201/782] Loss: 2.3093311