In [1]:
import torch
import torch.nn as nn

class VanillaSequencerBlock(nn.Module):
    def __init__(self, input_size, hidden_size, mlp_input_size, mlp_output_size):
        super(VanillaSequencerBlock, self).__init__()

        
        self.normal_layer = nn.LayerNorm(input_size)

        
        self.bilstm = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)

       
        self.normalization_after_merge = nn.LayerNorm(input_size)

        
        self.channel_mlp = nn.Linear(mlp_input_size, mlp_output_size)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.stack(x)
        batch_size,channel, height= x.size()
       
        outputs=[]
        for index in range(batch_size):
            
            
            y=x[index]
           
            y = self.normal_layer(y)
            
            
            
            y,_= self.bilstm(y)
            
            
            #y = y + output #following the paper instructions
            
            #y=y.squeeze(0)
           
            y=y.permute(1,0).clone()
            
            y = self.normalization_after_merge(y)

           
            
            y= self.channel_mlp(y)
            
            
            #y = y + channel_output  #following the paper instructions
            
            outputs.append(y)
       
        return outputs


In [2]:
import torch.nn.init as init

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        
        self.num_patches = (image_size // patch_size) ** 2 #compute the number of patches in such a way they won't overlap

        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) #convolutional layer takes the input image, divides it into patches, and embeds each patch into a lower-dimensional representation.
        self.bn = nn.BatchNorm2d(embed_dim)
        
    def forward(self, x):
        x = self.projection(x)
        B, C, H, W = x.size()
        x = x.reshape(B, self.num_patches, C, -1).permute(0, 2, 1, 3).clone()
        x = x.flatten(2).transpose(1, 2)
        return x

    def output_dimension(self):
        return self.embed_dim * self.num_patches


In [3]:
class PWLinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(PWLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        
    def forward(self, input_list):
        # output_list = []
        # for input_tensor in input_list:
        #     output_tensor = self.linear(input_tensor)
        #     output_list.append(output_tensor)
        # stacked_output = torch.stack(output_list, dim=0)
        
        stacked_output = torch.cat(input_list, dim=0)
        # Apply linear transformation
        print(stacked_output.shape)
        output_tensor = self.linear(stacked_output)
        return output_tensor


In [4]:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor):
        super(PatchMerging, self).__init__()
        self.scale_factor = scale_factor
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm1d(out_channels)
       

    def forward(self, patch_list):
        
        x = torch.cat(patch_list, dim=1).permute(1,0).clone() # compination of the list of patches into a single tensor + peermutation to fit convolutional layer
        
        x=x.unsqueeze(0).clone()
        
        x=x.permute(1,2,0).clone()
        x = self.conv(x) # Apply to the input x, a 1x1 convolution to merge the patches together
        
        
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode='nearest') #resize the feature map using a nearest-neighbor upsampling
        
        return x


In [5]:
class GlobalAveragePooling(nn.Module):
    def __init__(self):
        super(GlobalAveragePooling, self).__init__()
        
        
    def forward(self, x):
        print("bo global")
        
        x = torch.mean(x, dim=(-2, -1))# perform global average pooling along spatial dimensions using the mean
        return x


In [6]:

import torch.nn.init as init
import torch.nn.functional as F
class VanillaSequencerBlockModel(nn.Module):
    def __init__(self, num_classes, in_channels):
        super(VanillaSequencerBlockModel, self).__init__()

        self.num_classes = num_classes
        

       
        self.patch_embedding_1 = PatchEmbedding(32, 8, in_channels, 16)#  patch embedding with an 8x8 kernel size for each patch
        self.ln_1 = nn.LayerNorm(16)

        
        self.sequencer_block_1 =  nn.Sequential(
            VanillaSequencerBlock(16, 48, 16, 96),
            VanillaSequencerBlock(96, 96, 96, 192),
            VanillaSequencerBlock(192, 192, 192,384),
            VanillaSequencerBlock(384, 192, 384,384)
        )


        
        self.patch_merging=PatchMerging(49152,128,2)

      
        self.sequencer_block_2 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 384,384)
            #,
            #VanillaSequencerBlock(384, 192, 384,384),
            #VanillaSequencerBlock(384, 192, 384,384)
        )

       
        self.pw_linear_1 = PWLinearLayer( 384,384)
        
        
        self.sequencer_block_3 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384)
            
           
        )

        self.pw_linear_2 = PWLinearLayer(384, 384)

       
        self.sequencer_block_4 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384)
        )
        self.pw_linear_3 = PWLinearLayer(384, 384)


       
        self.ln_2 = nn.LayerNorm(384)

        
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 384))

        
        self.fc = nn.Linear(384, num_classes)
        
        
        

    def forward(self, x):
       
       
        x = x.permute(0,3,1,2).clone()#compute permutation
        x = self.patch_embedding_1(x)
        
        x = self.ln_1(x)
        
        x = F.relu(x)
        
        #x=x.permute(0,2,1)
        x = self.sequencer_block_1(x)
       
        
        x = self.patch_merging(x)
        
        x=x.permute(0,2,1).clone()
        
       #x = self.sequencer_block_2(x)

        
        #x = self.pw_linear_1(x)
        
        x = F.relu(x)
        
       
        #x = self.pw_linear_3(x)

       
        x = self.ln_2(x)
        
        x=x.unsqueeze(0).clone()
        x = self.global_avg_pool(x)
        
        x = self.fc(x)
     
        x=x.view(128,4)
        return x



In [7]:
from torchvision import datasets
import numpy as np
from torch.utils.data import Dataset

class CustomCIFAR2(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CustomCIFAR2, self).__init__()
        self.cifar10 = datasets.CIFAR10(root, train=train, transform=transform, target_transform=target_transform, download=download)
        
        
        self.keep_classes = [0, 1, 2, 3]  
        self.data, self.targets = self.filter_classes()

    def filter_classes(self):
        mask = np.isin(self.cifar10.targets, self.keep_classes)
        data = [self.cifar10.data[i] for i, include in enumerate(mask) if include]
        targets = [self.cifar10.targets[i] for i, include in enumerate(mask) if include]
        return data, targets

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        return img, target

    def __len__(self):
        return len(self.data)


In [8]:
import torch.optim as optim


import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VanillaSequencerBlockModel(num_classes=4, in_channels=3)


model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)



batch_size = 128

num_epochs = 10

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),   # Randomly flip the image horizontally
    transforms.RandomRotation(15),      # Randomly rotate the image by up to 15 degrees
   
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),# Adjust brightness, contrast, saturation, and hue
    transforms.RandomResizedCrop(16),
    transforms.RandomResizedCrop(4),# Randomly crop and resize the image to 224x224
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomAffine(degrees=4, translate=(0.4, 0.1)),# Randomly translate the image
    transforms.ToTensor(),              # Convert the image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image

])



#train_dataset = torchvision.datasets.CIFAR10(root='./cifar-10-batches-py', train=True, transform=transform)
custom_dataset = CustomCIFAR2(root='./cifar-10-batches-py', train=True, transform=transform, download=True)
train_dataset, val_dataset=train_test_split(custom_dataset,test_size=0.2, random_state=42)

data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



Files already downloaded and verified


In [9]:
import torch.nn.functional as F
all_accuracy=[]
all_loss=[]
# beginning of training
best_v_loss=float('inf')
val_iter=iter(val_loader)

#try:
for epoch in range(num_epochs):
    #for epoch in range(epoch +1 , epoch + num_epochs+1):
        total_correct = 0
        total_samples = 0
        v_total_sample=0
        v_total_corr=0
        running_loss = 0.0  # take track of the loss 
        
        model.train()

        for i, data in enumerate(data_loader, 0):
            if i==195:
                continue
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
        
            optimizer.zero_grad()
            
            inputs=inputs.to(torch.float32)
            outputs = model(inputs)
            
            

            
            loss = criterion(outputs, labels)
            all_loss.append(loss)
            
        
            loss.backward()
            optimizer.step()

            # Print or log gradients
            # print("Gradients (max/min/mean):")
            # for name, param in model.named_parameters():
            #     if param.grad is not None:
            #         print(f"{name}: {param.grad.max().item():.6f} / {param.grad.min().item():.6f} / {param.grad.mean().item():.6f}")
            
            running_loss += loss.item()
            predicted_probabilities = outputs.argmax(dim=1)
            print(predicted_probabilities)
            correct = (predicted_probabilities == labels).sum().item()
            total_correct += correct
            total_samples += labels.size(0)

            batch_accuracy = (correct / labels.size(0)) * 100.0
            all_accuracy.append(batch_accuracy)
            
            
            
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(data_loader)}] Loss: {loss.item():.4f} Accuracy: {batch_accuracy:.2f}%")
            if i % 200 == 199: 
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(data_loader)}] Loss: {loss.item():.4f} Accuracy: {batch_accuracy:.2f}%")
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}")
                running_loss = 0.0
            #torch.cuda.empty_cache(0)
                
            if i % 2 ==1:
                model.eval()
                v_loss=0.0
                correct_v_predictions=0
                
                all_val_acc=[]
                with torch.no_grad():
                    val_data=next(val_iter,None)
                    if val_data is None:
                        val_iter=iter(val_loader)
                        val_data=next(val_iter)
                    v_inputs,v_labels=val_data
                    v_inputs, v_labels= v_inputs.to(device), v_labels.to(device)
                    v_inputs=v_inputs.to(torch.float32)
                    v_output=model(v_inputs)
                    loss_val = criterion(v_output, v_labels)
                    v_loss += loss_val.item()

                    _,predicted_lab= torch.max(v_output,1)
                    correct_v_predictions+= (predicted_lab==v_labels).sum().item()
                    v_total_corr+=correct_v_predictions
                    v_total_sample=v_labels.size(0)

                    val_accuracy= (correct_v_predictions / v_labels.size(0))*100.0
                    all_val_acc.append(val_accuracy)
                    print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(val_loader)}] V_Loss: {loss_val.item():.4f} val_Accuracy: {val_accuracy:.2f}%")
                
       
        epoch_accuracy = (total_correct / total_samples) * 100.0
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                # Add any other relevant information
            }, 'checkpointVanilla1.pth')
            
        print(f"Epoch [{epoch+1}/{num_epochs}] Accuracy: {epoch_accuracy:.2f}%")

print("Finished Training")
#except Exception as e:
#    print(f"Error: {e}")

In [None]:
checkpoint_path= 'checkpointVanilla1.pth'
checkpoint= torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch= checkpoint['epoch']
loss= checkpoint['loss']

In [None]:
import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
import pickle

model_pkl_file = "vanillasequencer1_model384.pkl"

with open(model_pkl_file, 'wb') as file:  
    pickle.dump(model, file)

In [None]:
test_dataset = CustomCIFAR2(root='./cifar-10-batches-py', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [None]:
model.eval()

test_correct=0
test_total=0
test_loss=0
test_loss_total=0
i=0
with torch.no_grad():
    for test_data in test_loader:
        inputs, labels= test_data 
        inputs, labels= inputs.to(device), labels.to(device)
        i+=1
        inputs=inputs.to(torch.float32)
        test_outputs=model(inputs)
        test_loss =criterion(test_outputs, labels)
        test_loss_total+=test_loss
        _, test_predicted= torch.max(test_outputs, 1)
        test_total += labels.size(0)
        test_correct += (test_predicted == labels).sum().item() 

        print(f"Batch [{i+1}/{len(test_loader)}] Test loss: {test_loss.item():.4f}")
test_accuracy = (test_correct / test_total) * 100.0
test_loss /= len(test_loader)

print(f'Test Loss: {test_loss: .4f} Test Accuracy: {test_accuracy: .2f}%')

Batch [2/32] Test loss: 1.3818
Batch [3/32] Test loss: 1.3930
Batch [4/32] Test loss: 1.3805
Batch [5/32] Test loss: 1.3930
Batch [6/32] Test loss: 1.3939
Batch [7/32] Test loss: 1.3957
Batch [8/32] Test loss: 1.4009
Batch [9/32] Test loss: 1.3760
Batch [10/32] Test loss: 1.3920
Batch [11/32] Test loss: 1.3850
Batch [12/32] Test loss: 1.3911
Batch [13/32] Test loss: 1.3892
Batch [14/32] Test loss: 1.3818
Batch [15/32] Test loss: 1.3893
Batch [16/32] Test loss: 1.3912
Batch [17/32] Test loss: 1.3920
Batch [18/32] Test loss: 1.3843
Batch [19/32] Test loss: 1.3869
Batch [20/32] Test loss: 1.3878
Batch [21/32] Test loss: 1.3820
Batch [22/32] Test loss: 1.3812
Batch [23/32] Test loss: 1.4016
Batch [24/32] Test loss: 1.3875
Batch [25/32] Test loss: 1.3865
Batch [26/32] Test loss: 1.3878
Batch [27/32] Test loss: 1.3880
Batch [28/32] Test loss: 1.3905
Batch [29/32] Test loss: 1.3843
Batch [30/32] Test loss: 1.3791
Batch [31/32] Test loss: 1.3917
Batch [32/32] Test loss: 1.4010


RuntimeError: Given groups=1, weight of size [128, 49152, 1, 1], expected input[1, 12288, 384, 1] to have 49152 channels, but got 12288 channels instead