In [2]:
import math
from enum import Enum

import pytorch_lightning as pl
import torch
from torch import nn

from xformers.factory import xFormer, xFormerConfig

from torch.utils.tensorboard import SummaryWriter  # type: ignore

from gdeep.data import OrbitsGenerator, DataLoaderKwargs
from gdeep.pipeline import Pipeline

# %%
class Classifier(str, Enum):
    """How the transformer output sequence is reduced to one vector per sample.

    Members are also plain strings, so ``"gap"`` / ``"token"`` compare equal
    to the corresponding member and ``Classifier("gap")`` round-trips.

    GAP:   global average pooling over the sequence axis.
    TOKEN: read the output at a prepended, learned classification token.
    """

    GAP = "gap"
    TOKEN = "token"

class SetTransformer(pl.LightningModule):
    """Transformer encoder that classifies a set/sequence of points.

    Each point (a vector of ``input_dim`` features) is linearly embedded into
    ``dim`` dimensions. The resulting sequence goes through an xFormers
    encoder stack and a final LayerNorm. It is then reduced to one vector per
    sample (global average pooling or a learned classification token) and
    projected to ``num_classes`` logits.

    All constructor arguments are stored in ``self.hparams`` by
    ``save_hyperparameters``. ``image_size`` and ``patch_size`` are not used
    by this model; they are kept so existing calls and checkpoints that pass
    them keep working.
    """

    def __init__(
        self,
        steps,
        learning_rate=1e-2,
        weight_decay=0.0001,
        image_size=32,
        num_classes=10,
        patch_size=4,
        dim=256,
        n_layer=12,
        n_head=8,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        mlp_pdrop=0.1,
        attention="scaled_dot_product",
        hidden_layer_multiplier=4,
        linear_warmup_ratio=0.05,
        seq_len=1_000,
        classifier: Classifier = Classifier.GAP,
        input_dim: int = 2,
    ):
        """
        Args:
            steps: total number of scheduler steps used by
                ``configure_optimizers`` for warmup and cosine decay.
            learning_rate: base learning rate for the optimizer.
            weight_decay: weight decay for the optimizer.
            image_size: unused; kept for backward compatibility.
            num_classes: number of output logits.
            patch_size: unused; kept for backward compatibility.
            dim: model (embedding) width.
            n_layer: number of encoder layers.
            n_head: number of attention heads.
            resid_pdrop: residual dropout in the attention block.
            attn_pdrop: dropout inside the attention mechanism.
            mlp_pdrop: dropout in the feed-forward block.
            attention: xFormers attention name.
            hidden_layer_multiplier: feed-forward hidden width as a multiple
                of ``dim``.
            linear_warmup_ratio: fraction of ``steps`` spent in linear warmup.
            seq_len: sequence length passed to the xFormers block config.
            classifier: pooling strategy, a ``Classifier`` member or its
                string value ("gap" or "token").
            input_dim: number of features per input point (default 2,
                i.e. 2-D points).

        Raises:
            ValueError: if ``classifier`` is not a valid ``Classifier`` value.
        """
        super().__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Validate early: an unknown value would otherwise skip pooling in
        # ``forward`` and hand a (batch, seq, dim) tensor to the head.
        self._classifier = Classifier(classifier)

        # A list of the encoder or decoder blocks which constitute the Transformer.
        xformer_config = [
            {
                "block_config": {
                    "block_type": "encoder",
                    "num_layers": n_layer,
                    "dim_model": dim,
                    "seq_len": seq_len,
                    "layer_norm_style": "pre",
                    "multi_head_config": {
                        "num_heads": n_head,
                        "residual_dropout": resid_pdrop,
                        "use_rotary_embeddings": False,
                        "attention": {
                            "name": attention,
                            "dropout": attn_pdrop,
                            "causal": False,
                        },
                    },
                    "feedforward_config": {
                        "name": "MLP",
                        "dropout": mlp_pdrop,
                        "activation": "gelu",
                        "hidden_layer_multiplier": hidden_layer_multiplier,
                    },
                }
            }
        ]

        config = xFormerConfig(xformer_config)
        self.transformer = xFormer.from_config(config)

        # Per-point linear embedding (name kept for state_dict compatibility).
        self.patch_emb = nn.Linear(input_dim, dim)

        if self._classifier == Classifier.TOKEN:
            self.clf_token = nn.Parameter(torch.zeros(dim))

        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def linear_warmup_cosine_decay(warmup_steps, total_steps):
        """Build a ``LambdaLR`` multiplier function.

        The returned ``fn(step)`` ramps linearly from 0 to 1 over
        ``warmup_steps``. It then follows ``0.5 * (1 + cos(pi * progress))``,
        which reaches 0 at ``total_steps``.

        Note: for ``step > total_steps`` the progress exceeds 1 and the factor
        rises again, because cosine is periodic. ``total_steps`` should
        therefore be the real number of scheduler steps.
        """

        def fn(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))

            progress = float(step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return fn

    def configure_optimizers(self):
        """Return Lightning's ``([optimizer], [scheduler_dict])`` pair.

        Uses SGD with momentum 0.9 and a per-step (``"interval": "step"``)
        linear-warmup/cosine ``LambdaLR`` over ``hparams.steps`` steps.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.learning_rate,
            momentum=0.9,
            weight_decay=self.hparams.weight_decay,
        )

        warmup_steps = int(self.hparams.linear_warmup_ratio * self.hparams.steps)

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                self.linear_warmup_cosine_decay(warmup_steps, self.hparams.steps),
            ),
            "interval": "step",
        }

        return [optimizer], [scheduler]

    def forward(self, x):
        """Classify a batch of point sets.

        Args:
            x: assumed shape (batch, num_points, input_dim), batch-first
                (TODO confirm against the data loader).

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch = x.shape[0]

        # Embed every point independently: (B, N, input_dim) -> (B, N, dim).
        x = self.patch_emb(x)

        if self._classifier == Classifier.TOKEN:
            # Prepend the learned token along the sequence axis (dim=1). The
            # encoder input is batch-first, as the GAP branch's mean over dim=1
            # also assumes. The last point is dropped so the sequence length
            # stays unchanged, presumably to match the configured ``seq_len``.
            clf_token = self.clf_token.expand(batch, 1, -1)
            x = torch.cat([clf_token, x[:, :-1, :]], dim=1)

        x = self.transformer(x)
        x = self.ln(x)

        if self._classifier == Classifier.TOKEN:
            x = x[:, 0]  # output at the classification-token position
        else:  # Classifier.GAP (validated in __init__)
            x = x.mean(dim=1)  # mean over the sequence axis

        return self.head(x)
# %%
# Small SetTransformer for this demo. The training cell below reads the
# learning rate, weight decay and step budget back from ``model.hparams``.
model = SetTransformer(
    steps=500,
    learning_rate=1e-4,
    num_classes=5,
    dim=32,
    n_layer=2,
    n_head=4,
    attention="scaled_dot_product",
)

# %%
# Orbit data: validation and test percentages are 0, and only the training
# loader returned by the generator is kept.
homology_dimensions = (0, 1)

dataloaders_dicts = DataLoaderKwargs(
    train_kwargs={"batch_size": 64},
    val_kwargs={"batch_size": 4},
    test_kwargs={"batch_size": 3},
)

og = OrbitsGenerator(
    num_orbits_per_class=1_000,
    homology_dimensions=homology_dimensions,
    validation_percentage=0.0,
    test_percentage=0.0,
    n_jobs=2,
    # dynamical_system="pp_convention",  # alternative option, currently disabled
)

dl_train, _, _ = og.get_dataloader_orbits(dataloaders_dicts)

# %%
# Loss: the model's own cross-entropy criterion.
loss_fn = model.criterion

# TensorBoard writer (torch.utils.tensorboard.SummaryWriter) for pipeline logs.
writer = SummaryWriter()

# Pipeline over the training loader; the second loader slot is left as None.
pipe = Pipeline(model, [dl_train, None], loss_fn, writer)
# %%
# Training uses Adam configured here. ``model.configure_optimizers()`` (SGD +
# scheduler, meant for a Lightning Trainer) was previously called but its
# result was unused, so it is no longer invoked. The warmup/cosine schedule
# below uses the same helper and warmup length it would build.
warmup_steps = int(model.hparams.linear_warmup_ratio * model.hparams.steps)

print('warmup steps:', warmup_steps)

# train the model (last expression, so its result is displayed by the notebook)
pipe.train(
    torch.optim.Adam,
    model.hparams.steps,
    cross_validation=False,
    optimizers_param={
        "lr": model.hparams.learning_rate,
        "weight_decay": model.hparams.weight_decay,
    },
    lr_scheduler=torch.optim.lr_scheduler.LambdaLR,
    scheduler_params={
        "lr_lambda": model.linear_warmup_cosine_decay(
            warmup_steps, model.hparams.steps
        )
    },
)
# %%

blub
warmup steps: 25
Epoch 1
-------------------------------
No TPUs
Batch training loss:  1.7423405495900957  	Batch training accuracy:  19.990079365079367  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 18.945312%,                 Avg loss: 1.764437 

Epoch 2
-------------------------------
No TPUs
Batch training loss:  1.722286131646898  	Batch training accuracy:  19.990079365079367  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 18.945312%,                 Avg loss: 1.720310 

Epoch 3
-------------------------------
No TPUs
Batch training loss:  1.6773373153474596  	Batch training accuracy:  19.990079365079367  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 18.945312%,                 Avg loss: 1.661255 

Epoch 4
-------------------------------
No TPUs
Batch training loss:  1.6362711758840651  	Batch training ac

Batch training loss:  1.6108111975684998  	Batch training accuracy:  18.874007936507937  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 20.214844%,                 Avg loss: 1.608810 

Epoch 30
-------------------------------
No TPUs
Batch training loss:  1.6116391287909613  	Batch training accuracy:  18.328373015873016  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 19.628906%,                 Avg loss: 1.611108 

Epoch 31
-------------------------------
No TPUs
Batch training loss:  1.6113781872249784  	Batch training accuracy:  19.518849206349206  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 19.238281%,                 Avg loss: 1.611942 

Epoch 32
-------------------------------
No TPUs
Batch training loss:  1.6108449678572396  	Batch training accuracy:  17.509920634920633  	[ 63 / 63 ]                     
Time

Batch training loss:  1.6111948736130246  	Batch training accuracy:  20.13888888888889  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 20.214844%,                 Avg loss: 1.610505 

Epoch 58
-------------------------------
No TPUs
Batch training loss:  1.6108206491621713  	Batch training accuracy:  19.444444444444446  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 19.335938%,                 Avg loss: 1.610511 

Epoch 59
-------------------------------
No TPUs
Batch training loss:  1.6101361286072504  	Batch training accuracy:  19.866071428571427  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 22.558594%,                 Avg loss: 1.609941 

Epoch 60
-------------------------------
No TPUs
Batch training loss:  1.6106871972008356  	Batch training accuracy:  18.278769841269842  	[ 63 / 63 ]                     
Time

Batch training loss:  1.6112266778945923  	Batch training accuracy:  19.196428571428573  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 18.945312%,                 Avg loss: 1.613167 

Epoch 86
-------------------------------
No TPUs
Batch training loss:  1.6103358249815682  	Batch training accuracy:  21.205357142857142  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 19.238281%,                 Avg loss: 1.610063 

Epoch 87
-------------------------------
No TPUs
Batch training loss:  1.610704919648549  	Batch training accuracy:  18.452380952380953  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 16.308594%,                 Avg loss: 1.609281 

Epoch 88
-------------------------------
No TPUs
Batch training loss:  1.6103872268918962  	Batch training accuracy:  19.791666666666664  	[ 63 / 63 ]                     
Time

Batch training loss:  1.6108635001712375  	Batch training accuracy:  19.09722222222222  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 18.457031%,                 Avg loss: 1.609668 

Epoch 114
-------------------------------
No TPUs
Batch training loss:  1.6098086587966434  	Batch training accuracy:  18.998015873015873  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 19.824219%,                 Avg loss: 1.609684 

Epoch 115
-------------------------------
No TPUs
Batch training loss:  1.6100333796607122  	Batch training accuracy:  19.46924603174603  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 21.679688%,                 Avg loss: 1.609461 

Epoch 116
-------------------------------
No TPUs
Batch training loss:  1.610615329136924  	Batch training accuracy:  19.816468253968253  	[ 63 / 63 ]                      
T

Batch training loss:  1.609273081734067  	Batch training accuracy:  19.270833333333336  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 21.289062%,                 Avg loss: 1.608959 

Epoch 142
-------------------------------
No TPUs
Batch training loss:  1.6087882102481903  	Batch training accuracy:  17.708333333333336  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 20.117188%,                 Avg loss: 1.609786 

Epoch 143
-------------------------------
No TPUs
Batch training loss:  1.6086567678148784  	Batch training accuracy:  20.411706349206348  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 20.507812%,                 Avg loss: 1.609352 

Epoch 144
-------------------------------
No TPUs
Batch training loss:  1.6081218038286482  	Batch training accuracy:  20.7093253968254  	[ 63 / 63 ]                       
T

Batch training loss:  1.58872385819753  	Batch training accuracy:  21.949404761904763  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 21.289062%,                 Avg loss: 1.588373 

Epoch 170
-------------------------------
No TPUs
Batch training loss:  1.5875504338552082  	Batch training accuracy:  24.00793650793651  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 24.511719%,                 Avg loss: 1.582751 

Epoch 171
-------------------------------
No TPUs
Batch training loss:  1.585250108961075  	Batch training accuracy:  23.487103174603174  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 25.292969%,                 Avg loss: 1.580344 

Epoch 172
-------------------------------
No TPUs
Batch training loss:  1.5834987163543701  	Batch training accuracy:  26.53769841269841  	[ 63 / 63 ]                      
T

Batch training loss:  1.5110884147977073  	Batch training accuracy:  37.92162698412698  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 33.496094%,                 Avg loss: 1.500320 

Epoch 198
-------------------------------
No TPUs
Batch training loss:  1.5110740245334686  	Batch training accuracy:  35.1438492063492  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 30.761719%,                 Avg loss: 1.493161 

Epoch 199
-------------------------------
No TPUs
Batch training loss:  1.5056540379448542  	Batch training accuracy:  38.913690476190474  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 27.246094%,                 Avg loss: 1.504759 

Epoch 200
-------------------------------
No TPUs
Batch training loss:  1.5036460142286996  	Batch training accuracy:  36.507936507936506  	[ 63 / 63 ]                     
T

Batch training loss:  1.367854939566718  	Batch training accuracy:  42.06349206349206  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 37.500000%,                 Avg loss: 1.385874 

Epoch 226
-------------------------------
No TPUs
Batch training loss:  1.3584657756109086  	Batch training accuracy:  42.28670634920635  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 35.449219%,                 Avg loss: 1.391989 

Epoch 227
-------------------------------
No TPUs
Batch training loss:  1.3563131756252713  	Batch training accuracy:  41.76587301587302  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 37.402344%,                 Avg loss: 1.391601 

Epoch 228
-------------------------------
No TPUs
Batch training loss:  1.354501179286412  	Batch training accuracy:  40.99702380952381  	[ 63 / 63 ]                       
T

Batch training loss:  1.2836703156668043  	Batch training accuracy:  41.66666666666667  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 37.500000%,                 Avg loss: 1.382723 

Epoch 254
-------------------------------
No TPUs
Batch training loss:  1.271795392036438  	Batch training accuracy:  43.67559523809524  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 40.820312%,                 Avg loss: 1.349189 

Epoch 255
-------------------------------
No TPUs
Batch training loss:  1.2683561245600383  	Batch training accuracy:  43.37797619047619  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 40.234375%,                 Avg loss: 1.277871 

Epoch 256
-------------------------------
No TPUs
Batch training loss:  1.2706073711788843  	Batch training accuracy:  43.57638888888889  	[ 63 / 63 ]                      
T

Batch training loss:  1.2113306579135714  	Batch training accuracy:  45.56051587301587  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 40.820312%,                 Avg loss: 1.315258 

Epoch 282
-------------------------------
No TPUs
Batch training loss:  1.2096924819643535  	Batch training accuracy:  45.03968253968254  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 42.773438%,                 Avg loss: 1.254268 

Epoch 283
-------------------------------
No TPUs
Batch training loss:  1.2007965511745877  	Batch training accuracy:  46.47817460317461  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 44.433594%,                 Avg loss: 1.237540 

Epoch 284
-------------------------------
No TPUs
Batch training loss:  1.2024350033866034  	Batch training accuracy:  46.1061507936508  	[ 63 / 63 ]                       
T

Batch training loss:  1.146445667932904  	Batch training accuracy:  48.61111111111111  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 48.339844%,                 Avg loss: 1.143879 

Epoch 310
-------------------------------
No TPUs
Batch training loss:  1.1493982284788102  	Batch training accuracy:  48.214285714285715  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 46.679688%,                 Avg loss: 1.207866 

Epoch 311
-------------------------------
No TPUs
Batch training loss:  1.1431206899975974  	Batch training accuracy:  48.660714285714285  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 48.144531%,                 Avg loss: 1.126849 

Epoch 312
-------------------------------
No TPUs
Batch training loss:  1.143893976060171  	Batch training accuracy:  48.735119047619044  	[ 63 / 63 ]                      
T

Batch training loss:  1.102451209038023  	Batch training accuracy:  50.86805555555556  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 53.027344%,                 Avg loss: 1.087565 

Epoch 338
-------------------------------
No TPUs
Batch training loss:  1.1004516132294186  	Batch training accuracy:  51.11607142857143  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 50.585938%,                 Avg loss: 1.119787 

Epoch 339
-------------------------------
No TPUs
Batch training loss:  1.092538821318793  	Batch training accuracy:  51.78571428571429  	[ 63 / 63 ]                       
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 50.976562%,                 Avg loss: 1.100046 

Epoch 340
-------------------------------
No TPUs
Batch training loss:  1.0970324996917966  	Batch training accuracy:  50.81845238095239  	[ 63 / 63 ]                      
T

Batch training loss:  1.0729609065585666  	Batch training accuracy:  52.05853174603175  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 54.882812%,                 Avg loss: 1.043704 

Epoch 366
-------------------------------
No TPUs
Batch training loss:  1.0709298101682512  	Batch training accuracy:  51.587301587301596  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 53.710938%,                 Avg loss: 1.055636 

Epoch 367
-------------------------------
No TPUs
Batch training loss:  1.0698156479805234  	Batch training accuracy:  52.182539682539684  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 54.003906%,                 Avg loss: 1.054881 

Epoch 368
-------------------------------
No TPUs
Batch training loss:  1.0665564262677754  	Batch training accuracy:  51.5625  	[ 63 / 63 ]                                
T

Batch training loss:  1.0499258826649378  	Batch training accuracy:  52.67857142857143  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 54.882812%,                 Avg loss: 1.031872 

Epoch 394
-------------------------------
No TPUs
Batch training loss:  1.0486069075644961  	Batch training accuracy:  52.28174603174603  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 54.882812%,                 Avg loss: 1.033020 

Epoch 395
-------------------------------
No TPUs
Batch training loss:  1.0517801680262127  	Batch training accuracy:  52.951388888888886  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 55.468750%,                 Avg loss: 1.030393 

Epoch 396
-------------------------------
No TPUs
Batch training loss:  1.0472052475762745  	Batch training accuracy:  52.38095238095239  	[ 63 / 63 ]                      
T

Batch training loss:  1.0398112611165122  	Batch training accuracy:  53.199404761904766  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 55.761719%,                 Avg loss: 1.009084 

Epoch 422
-------------------------------
No TPUs
Batch training loss:  1.0421619074685233  	Batch training accuracy:  53.000992063492056  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 55.957031%,                 Avg loss: 1.003658 

Epoch 423
-------------------------------
No TPUs
Batch training loss:  1.0364104443126254  	Batch training accuracy:  53.050595238095234  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 55.664062%,                 Avg loss: 1.025253 

Epoch 424
-------------------------------
No TPUs
Batch training loss:  1.031600832939148  	Batch training accuracy:  53.050595238095234  	[ 63 / 63 ]                      
T

Batch training loss:  1.0317579214535062  	Batch training accuracy:  52.653769841269835  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.347656%,                 Avg loss: 1.006285 

Epoch 450
-------------------------------
No TPUs
Batch training loss:  1.0313678658197796  	Batch training accuracy:  53.050595238095234  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.152344%,                 Avg loss: 1.012680 

Epoch 451
-------------------------------
No TPUs
Batch training loss:  1.0289982746517847  	Batch training accuracy:  52.75297619047619  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.250000%,                 Avg loss: 1.016863 

Epoch 452
-------------------------------
No TPUs
Batch training loss:  1.026213530510191  	Batch training accuracy:  53.125  	[ 63 / 63 ]                                  
T

Batch training loss:  1.02640938569629  	Batch training accuracy:  53.57142857142857  	[ 63 / 63 ]                        
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.445312%,                 Avg loss: 1.001639 

Epoch 478
-------------------------------
No TPUs
Batch training loss:  1.0199011365572612  	Batch training accuracy:  53.79464285714286  	[ 63 / 63 ]                      
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.738281%,                 Avg loss: 1.004231 

Epoch 479
-------------------------------
No TPUs
Batch training loss:  1.0237681761620536  	Batch training accuracy:  53.596230158730165  	[ 63 / 63 ]                     
Time taken for this epoch: 6s
No TPUs
Validation results: 
 Accuracy: 56.835938%,                 Avg loss: 1.008986 

Epoch 480
-------------------------------
No TPUs
Batch training loss:  1.0249117830443004  	Batch training accuracy:  52.90178571428571  	[ 63 / 63 ]                      
T

(1.0020307675004005, 56.73828125)

In [3]:
# Display the module tree of the model trained in the previous cell.
model

SetTransformer(
  (transformer): xFormer(
    (encoders): ModuleList(
      (0): xFormerEncoderBlock(
        (mha): MultiHeadDispatch(
          (attention): ScaledDotProduct(
            (attn_drop): Dropout(p=0.1, inplace=False)
          )
          (in_proj_container): InProjContainer()
          (resid_drop): Dropout(p=0.1, inplace=False)
          (proj): Linear(in_features=32, out_features=32, bias=True)
        )
        (feedforward): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=32, out_features=128, bias=True)
            (1): GELU()
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=128, out_features=32, bias=True)
            (4): Dropout(p=0.1, inplace=False)
          )
        )
        (wrap_att): Residual(
          (layer): PreNorm(
            (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
            (sublayer): MultiHeadDispatch(
              (attention): ScaledDotProduct(
                (att