# X' (X Prime) Experiment 
- Add the location channels once, to the input image, rather than at each layer step.

In [None]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 


# Train + Data 
import sys 
sys.path.append('../Layers')
from Conv1d_NN_spatial import * 
from Conv2d_NN_spatial import * 

sys.path.append('../Data')
from CIFAR10 import * 


sys.path.append('../Models')
from models_2d import *

sys.path.append('../Train')
from train2d import * 


In [None]:
# Build the dataset object used by the experiments in this notebook.
# NOTE(review): CIFAR10 is presumably provided by the star-import from ../Data/CIFAR10 — confirm.
cifar10 = CIFAR10()

In [None]:
## CNN
class CNN(nn.Module):
    """Baseline CNN: two 2-D convolutions followed by a two-layer classifier.

    The first convolution is built with ``in_ch + 3`` input channels, so the
    tensor given to ``forward`` must already carry the three extra channels;
    ``forward`` itself does not add them.

    ``fc1`` expects 32768 flattened features, which is 32 channels x 32 x 32,
    i.e. a 32x32 input with ``kernel_size=3`` (padding is fixed at 1, so only a
    3x3 kernel preserves the spatial size).

    Args:
        in_ch: Number of image channels before the 3 extra channels are added.
        num_classes: Size of the output logit vector.
        kernel_size: Kernel size used by both convolutions.

    Attributes:
        name: Human-readable model name, ``"CNN"``.
    """

    def __init__(self, in_ch=3, num_classes=10, kernel_size=3):
        super().__init__()

        # NOTE(review): the "+ 3" presumably accounts for the location channels
        # of the X' experiment — confirm where they are attached to the input.
        self.conv1 = nn.Conv2d(in_ch + 3, 16, kernel_size=kernel_size, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=1, padding=1)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(32768, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

        self.relu = nn.ReLU()
        # Use MPS when present, otherwise stay on CPU instead of raising.
        self.to(self._default_device())
        self.name = "CNN"

    @staticmethod
    def _default_device():
        """Return ``"mps"`` when the MPS backend is usable, else ``"cpu"``."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Input batch; must have ``in_ch + 3`` channels to match ``conv1``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def summary(self, input_size=(3, 32, 32)):
        """Print a model summary computed on the CPU.

        NOTE(review): the default ``input_size`` has 3 channels while ``conv1``
        expects ``in_ch + 3``; pass a matching size explicitly — confirm intent.
        ``summary`` is expected to come from one of the notebook's star-imports.

        Args:
            input_size: Per-sample input shape forwarded to ``summary``.
        """
        self.to("cpu")
        try:
            print(summary(self, input_size))
        finally:
            # Restore the working device even if the summary call fails.
            self.to(self._default_device())

In [None]:
class ConvNN_2D_K_All(nn.Module):
    """Two ``Conv2d_NN`` layers (``samples="all"``) followed by a two-layer classifier.

    The first layer is built with ``in_ch + 3`` input channels, so the tensor
    given to ``forward`` must already carry the three extra channels;
    ``forward`` itself does not add them.

    Args:
        in_ch: Number of image channels before the 3 extra channels are added.
        num_classes: Size of the output logit vector.
        K: Passed to ``Conv2d_NN`` as both ``K`` and ``stride``.

    Attributes:
        name: Human-readable model name, ``"ConvNN_2D_K_All"``.
    """

    def __init__(self, in_ch=3, num_classes=10, K=9):
        super().__init__()

        # NOTE(review): the "+ 3" presumably accounts for the location channels
        # of the X' experiment — confirm where they are attached to the input.
        self.conv1 = Conv2d_NN(in_ch + 3, 16, K=K, stride=K, shuffle_pattern="BA", shuffle_scale=2, samples="all")
        self.conv2 = Conv2d_NN(16, 32, K=K, stride=K, shuffle_pattern="BA", shuffle_scale=2, samples="all")

        self.flatten = nn.Flatten()

        # 32768 flattened features; presumably 32 channels x 32 x 32 — confirm
        # against the output shape of Conv2d_NN.
        self.fc1 = nn.Linear(32768, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

        self.relu = nn.ReLU()
        # Use MPS when present, otherwise stay on CPU instead of raising.
        self.to(self._default_device())
        self.name = "ConvNN_2D_K_All"

    @staticmethod
    def _default_device():
        """Return ``"mps"`` when the MPS backend is usable, else ``"cpu"``."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Input batch; must have ``in_ch + 3`` channels to match ``conv1``.

        Returns:
            Tensor whose last dimension is ``num_classes``.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def summary(self, input_size=(3, 32, 32)):
        """Print a model summary computed on the CPU.

        NOTE(review): the default ``input_size`` has 3 channels while ``conv1``
        expects ``in_ch + 3``; pass a matching size explicitly — confirm intent.
        ``summary`` is expected to come from one of the notebook's star-imports.

        Args:
            input_size: Per-sample input shape forwarded to ``summary``.
        """
        self.to("cpu")
        try:
            print(summary(self, input_size))
        finally:
            # Restore the working device even if the summary call fails.
            self.to(self._default_device())

    def coordinate_channels(self, tensor_shape, device):
        """Build normalised (row, column) index channels for a 4-D tensor shape.

        Not called by ``forward`` in this class. It produces 2 channels, whereas
        ``conv1`` is built for 3 extra channels.
        NOTE(review): confirm how the third extra channel is meant to be formed.

        Args:
            tensor_shape: Indexable shape; entries 0, 2 and 3 are read as
                batch size, height and width.
            device: Device the result is moved to.

        Returns:
            Float tensor of shape ``(batch, 2, height, width)``. Channel 0 holds
            the row index and channel 1 the column index, L2-normalised across
            the two channels at every pixel (the origin pixel stays all-zero).
        """
        x_ind = torch.arange(0, tensor_shape[2])
        y_ind = torch.arange(0, tensor_shape[3])

        x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')

        # (H, W) -> (batch, 1, H, W) for each index plane.
        x_grid = x_grid.float().unsqueeze(0).expand(tensor_shape[0], -1, -1).unsqueeze(1)
        y_grid = y_grid.float().unsqueeze(0).expand(tensor_shape[0], -1, -1).unsqueeze(1)

        xy_grid = torch.cat((x_grid, y_grid), dim=1)
        xy_grid_normalized = F.normalize(xy_grid, p=2, dim=1)
        return xy_grid_normalized.to(device)

In [None]:
class Branching_ConvNN_2D_Spatial_K_N(nn.Module):
    """Two ``ConvNN_CNN_Spatial_BranchingLayer`` blocks followed by a two-layer classifier.

    The first layer is built with ``in_ch + 3`` input channels, so the tensor
    given to ``forward`` must already carry the three extra channels;
    ``forward`` itself does not add them.

    Args:
        in_ch: Number of image channels before the 3 extra channels are added.
        channel_ratio: Two-element sequence forwarded to the first layer; the
            second layer receives both entries doubled.
        num_classes: Size of the output logit vector.
        kernel_size: Forwarded to both branching layers.
        K: Forwarded to both branching layers.
        N: Forwarded to both branching layers as ``samples``.
        location_channels: Forwarded to both branching layers.

    Attributes:
        name: Human-readable model name, ``"Branching_ConvNN_2D_Spatial_K_N"``.
    """

    def __init__(self, in_ch=3, channel_ratio=(16, 16), num_classes=10, kernel_size=3, K=9, N=8, location_channels=False):
        super().__init__()

        # NOTE(review): the "+ 3" presumably accounts for the location channels
        # of the X' experiment — confirm where they are attached to the input.
        self.conv1 = ConvNN_CNN_Spatial_BranchingLayer(
            in_ch + 3, 16,
            channel_ratio=channel_ratio, kernel_size=kernel_size, K=K, samples=N,
            location_channels=location_channels,
        )
        self.conv2 = ConvNN_CNN_Spatial_BranchingLayer(
            16, 32,
            channel_ratio=(channel_ratio[0] * 2, channel_ratio[1] * 2),
            kernel_size=kernel_size, K=K, samples=N,
            location_channels=location_channels,
        )

        self.flatten = nn.Flatten()

        # 32768 flattened features; presumably 32 channels x 32 x 32 — confirm
        # against the output shape of the branching layers.
        self.fc1 = nn.Linear(32768, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

        self.relu = nn.ReLU()
        # Use MPS when present, otherwise stay on CPU instead of raising.
        self.to(self._default_device())
        self.name = "Branching_ConvNN_2D_Spatial_K_N"

    @staticmethod
    def _default_device():
        """Return ``"mps"`` when the MPS backend is usable, else ``"cpu"``."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def forward(self, x):
        """Return class logits for ``x``.

        No ReLU is applied here after the two branching layers; only ``fc1`` is
        followed by one.

        Args:
            x: Input batch; must have ``in_ch + 3`` channels to match ``conv1``.

        Returns:
            Tensor whose last dimension is ``num_classes``.
        """
        x = self.conv1(x)
        x = self.conv2(x)

        x = self.flatten(x)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def summary(self, input_size=(3, 32, 32)):
        """Print a model summary computed on the CPU.

        NOTE(review): the default ``input_size`` has 3 channels while ``conv1``
        expects ``in_ch + 3``; pass a matching size explicitly — confirm intent.
        ``summary`` is expected to come from one of the notebook's star-imports.

        Args:
            input_size: Per-sample input shape forwarded to ``summary``.
        """
        self.to("cpu")
        try:
            print(summary(self, input_size))
        finally:
            # Restore the working device even if the summary call fails.
            self.to(self._default_device())

    def coordinate_channels(self, tensor_shape, device):
        """Build normalised (row, column) index channels for a 4-D tensor shape.

        Not called by ``forward`` in this class. It produces 2 channels, whereas
        ``conv1`` is built for 3 extra channels.
        NOTE(review): confirm how the third extra channel is meant to be formed.

        Args:
            tensor_shape: Indexable shape; entries 0, 2 and 3 are read as
                batch size, height and width.
            device: Device the result is moved to.

        Returns:
            Float tensor of shape ``(batch, 2, height, width)``. Channel 0 holds
            the row index and channel 1 the column index, L2-normalised across
            the two channels at every pixel (the origin pixel stays all-zero).
        """
        x_ind = torch.arange(0, tensor_shape[2])
        y_ind = torch.arange(0, tensor_shape[3])

        x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')

        # (H, W) -> (batch, 1, H, W) for each index plane.
        x_grid = x_grid.float().unsqueeze(0).expand(tensor_shape[0], -1, -1).unsqueeze(1)
        y_grid = y_grid.float().unsqueeze(0).expand(tensor_shape[0], -1, -1).unsqueeze(1)

        xy_grid = torch.cat((x_grid, y_grid), dim=1)
        xy_grid_normalized = F.normalize(xy_grid, p=2, dim=1)
        return xy_grid_normalized.to(device)
