# What is it about?

## Pure code for using the transformer model, nothing else.

In [None]:
class PatchEncoder(nn.Module):
    """Add a learned position embedding to a sequence of patch embeddings.

    One embedding vector of size ``projection_dim`` is learned for each of the
    ``num_patches`` positions. ``forward`` adds the full table, shaped
    ``(1, num_patches, projection_dim)``, to its input, so the input must be
    broadcastable against that shape.

    Args:
        num_patches: Number of positions in the embedding table.
        projection_dim: Size of each position embedding vector.
    """

    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()

        # Stored so get_config() can report it; previously the attribute was
        # never set and get_config() raised AttributeError.
        self.num_patches = num_patches

        self.position_embedding = nn.Embedding(num_patches, projection_dim)
        # Registered as a buffer so the index tensor follows the module across
        # .to(device) calls without becoming a trainable parameter.
        self.register_buffer('positions', torch.arange(num_patches).unsqueeze(0))

    def forward(self, patch):
        """Return ``patch`` plus the position embedding of every position."""
        encoded = patch + self.position_embedding(self.positions)
        return encoded

    def get_config(self):
        """Return the constructor setting ``num_patches`` as a dict."""
        config = {"num_patches": self.num_patches}
        return config

In [None]:
class MLP(nn.Module):
    """Feed-forward stack of ``Linear -> GELU -> Dropout`` stages.

    Args:
        input_dim: Feature size of the incoming tensor (last dimension).
        hidden_units: Iterable with the output width of each stage, in order.
        dropout_rate: Dropout probability used after every GELU.
    """

    def __init__(self, input_dim, hidden_units, dropout_rate):
        super().__init__()

        # Consecutive pairs of this list are the (in, out) widths of each stage.
        widths = [input_dim, *hidden_units]

        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            ]

        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stage and return the result."""
        return self.mlp(x)
    
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each with a residual.

    Args:
        projection_dim: Feature size of the input; also the attention embed dim.
        num_heads: Number of attention heads.
        transformer_units: Stage widths handed to ``MLP``. The last width has to
            be compatible with ``projection_dim`` because the MLP output is
            added back onto the attention residual.
        dropout_rate: Dropout probability used inside the attention and the MLP.

    NOTE(review): ``nn.MultiheadAttention`` is created without ``batch_first``,
    so by PyTorch's default it treats dim 0 of its input as the sequence axis.
    ``AdvancedTransformerModel`` passes tensors whose dim 0 is the batch, which
    means attention runs across samples of a batch. Confirm this is intended
    before changing it, since existing checkpoints were trained this way.
    """

    def __init__(self, projection_dim, num_heads, transformer_units, dropout_rate=0.01):
        super().__init__()

        ln_eps = 1e-6

        self.norm1 = nn.LayerNorm(projection_dim, eps=ln_eps)
        self.attention = nn.MultiheadAttention(
            projection_dim,
            num_heads,
            dropout=dropout_rate,
        )
        self.norm2 = nn.LayerNorm(projection_dim, eps=ln_eps)
        self.mlp = MLP(projection_dim, transformer_units, dropout_rate)

    def forward(self, x):
        """Apply attention and MLP sub-layers to ``x`` and return the result."""
        normed = self.norm1(x)
        # Self-attention: query, key and value are all the normalised input.
        # The attention weights (second return value) are not needed.
        attended = self.attention(normed, normed, normed)[0]
        residual = x + attended

        return residual + self.mlp(self.norm2(residual))

In [None]:
class AdvancedTransformerModel(nn.Module):
    """Transformer regressor that maps a length-``color_size`` vector to one scalar.

    Pipeline: linear projection of the whole input vector, learned position
    embeddings, ``num_layers`` transformer blocks, average pooling over the
    position axis, then three linear layers down to a single value per sample.

    Args:
        color_size: Length of each input vector; also the number of positions.
        projection_dim: Width of the internal representation.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        dropout_rate: Dropout probability applied between the output layers.
    """

    def __init__(self, color_size, projection_dim, num_heads, num_layers, dropout_rate=0.1):
        super(AdvancedTransformerModel, self).__init__()

        self.projection_dim = projection_dim
        self.embedding = nn.Linear(color_size, projection_dim)
        self.positional_encoder = PatchEncoder(color_size, projection_dim)
        self.transformer_units = [projection_dim * 2, projection_dim]
        # NOTE(review): dropout_rate is not forwarded here, so the blocks use
        # TransformerBlock's own default rate. Confirm that is intended.
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(projection_dim, num_heads, self.transformer_units) for _ in range(num_layers)]
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        self.dense_v1 = nn.Linear(projection_dim, projection_dim)
        self.dense_v2 = nn.Linear(projection_dim, projection_dim)
        self.dense_v3 = nn.Linear(projection_dim, 1)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return one prediction per row of ``x``.

        Args:
            x: Tensor of shape ``(batch, color_size)``. Integer and bool tensors
                are cast to the embedding's floating dtype first.

        Returns:
            1-D tensor with ``batch`` elements.
        """
        # nn.Linear rejects integer inputs. The original only handled uint8,
        # so int64 (or any other non-float) inputs crashed here.
        if not x.is_floating_point():
            x = x.to(self.embedding.weight.dtype)

        # (batch, color_size) -> (batch, projection_dim) -> (batch, 1, projection_dim)
        x = self.embedding(x).unsqueeze(1)
        # Broadcasts against the (1, color_size, projection_dim) position table,
        # giving (batch, color_size, projection_dim).
        x = self.positional_encoder(x)

        for block in self.transformer_blocks:
            x = block(x)

        # AdaptiveAvgPool1d pools the last axis, so move positions there.
        x = x.permute(0, 2, 1)

        x = self.global_avg_pool(x).squeeze(-1)

        v = self.dense_v1(x)
        v = self.dropout(v)
        v = self.dense_v2(v)
        v = self.dropout(v)
        v = self.dense_v3(v)

        return v.flatten()

In [None]:
# Build the network from the notebook-level settings (n, param) and move it to
# the target device. Dropout between the output layers is switched off here.
model = AdvancedTransformerModel(
    color_size=n,
    projection_dim=param[0],
    num_heads=4,
    num_layers=param[1],
    dropout_rate=0.0,
)
model = model.to(device)

In [None]:
# Adam over every parameter of the model, learning rate 1e-3.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
)

In [None]:
# Path of the checkpoint file to restore; replace the placeholder with a real
# path. The loading cell below rebinds this name to the loaded dictionary, so
# re-run this cell before loading again.
checkpoint = "copy/path/to/checkpoint.pth"

In [None]:
# Restore model weights, optimizer state, loss function and loss history.
# NOTE: torch.load unpickles arbitrary Python objects, so only open checkpoint
# files that come from a trusted source.
# map_location=device remaps the stored tensors, so a checkpoint written on a
# GPU also loads on a machine without one.
checkpoint = torch.load(checkpoint, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
criterion = checkpoint["loss"]
losses = checkpoint['losses']

# If you want to train

In [None]:
def trainer(model, dataset_size, n_epochs, once=32, verbose=True, optimizer=None, criterion=None):
    """Train ``model`` on freshly generated random walks and return it with the loss history.

    Every epoch draws a new ``(X, y)`` set from ``cayley_group_ex.random_walks``
    and steps through it in ``dataset_size // once`` consecutive slices of
    ``once`` rows. Relies on the module-level names ``cayley_group_ex``, ``CFG``
    and ``device``.

    Args:
        model: Module to train; it is switched to training mode and updated in place.
        dataset_size: Nominal number of samples per epoch. It sets the number of
            batches, the second argument of ``random_walks`` and the divisor of
            the reported metrics.
        n_epochs: Number of epochs to run.
        once: Batch size.
        verbose: When true, print the loss, MSE and RMSE after every epoch.
        optimizer: Optimizer to use; defaults to Adam with ``lr=1e-6``.
        criterion: Loss function; defaults to ``nn.MSELoss()``.

    Returns:
        Tuple ``(model, losses)`` where ``losses`` holds one value per epoch:
        the sum of the per-batch loss values divided by ``dataset_size``.
    """
    model.train()

    criterion = nn.MSELoss() if criterion is None else criterion
    optimizer = optim.Adam(model.parameters(), lr=1e-6) if optimizer is None else optimizer

    losses = []
    # Guards the metric divisions below when dataset_size is 0.
    n_samples = max(dataset_size, 1)

    for epoch in range(n_epochs):
        running_loss = 0.0
        mse_sum = 0.0

        # if epoch % 50 == 0:
        #     checkpoint = {
        #         'epoch': epoch,
        #         'model_state_dict': model.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict(),
        #         'loss': criterion,
        #     }

        #     torch.save(checkpoint, 'checkpoint.pth')

        # NOTE(review): the 200 presumably mirrors CFG['n_steps_limit'], so that
        # random_walks yields about dataset_size rows. Verify against random_walks.
        X, y = cayley_group_ex.random_walks(CFG['n_steps_limit'], dataset_size // 200)

        for batch in range(dataset_size // once):
            start = batch * once
            X_ = X[start:start + once].float().to(device)
            y_ = y[start:start + once].float().to(device)

            # Clear gradients left over from the previous step; without this
            # every backward() call adds onto all earlier gradients.
            optimizer.zero_grad()

            preds = model(X_)
            loss = criterion(preds, y_)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # detach(): this metric needs no autograd graph.
            mse_sum += ((preds.detach() - y_) ** 2).sum().item()

        # NOTE(review): running_loss sums per-batch loss values, so with a
        # mean-reducing criterion this is smaller than the per-sample loss by a
        # factor of `once`. Kept unchanged to stay comparable with stored histories.
        train_loss = running_loss / n_samples
        train_mse = mse_sum / n_samples
        train_rmse = train_mse ** 0.5

        # Record the history whether or not progress is printed.
        losses.append(train_loss)

        if verbose:
            print(f'Epoch {epoch}/{n_epochs}: loss: {train_loss};')
            print(f'MSE: {train_mse} RMSE: {train_rmse}')
            print()

    # Drop the remaining references to batch tensors so the caching allocator
    # can release them. Assigning None is safe even when no epoch or batch ran,
    # where the previous `del` statements raised UnboundLocalError.
    X = y = X_ = y_ = preds = loss = None
    torch.cuda.empty_cache()

    return model, losses