<a href="https://colab.research.google.com/github/gg-dema/geometric-algebraic-transformer/blob/main/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Baseline


## DataModule

In [None]:
class PureFeaturesDataset(Dataset):
    """Binary-classification dataset built from two groups of feature tensors.

    Items of ``single_tensor`` are labelled 0 and items of
    ``bifurcation_tensor`` are labelled 1. The two groups are concatenated
    along the first dimension (singles first) and every feature vector is
    L2-normalised along its last dimension.
    """

    def __init__(self, single_tensor, bifurcation_tensor):
        """
        Args:
            single_tensor (torch.Tensor): samples of class 0; the first
                dimension indexes the samples.
            bifurcation_tensor (torch.Tensor): samples of class 1; must match
                ``single_tensor`` in every dimension except the first so the
                two can be concatenated.
        """
        super(PureFeaturesDataset, self).__init__()
        labels_single = torch.zeros(single_tensor.size(0))
        labels_bifurcating = torch.ones(bifurcation_tensor.size(0))

        # Concatenate the constructor arguments. The previous code referenced
        # a misspelled name here, which either raised NameError or silently
        # picked up an unrelated global tensor with that name.
        data = torch.cat((single_tensor, bifurcation_tensor), dim=0)
        self.data = F.normalize(data, p=2.0, dim=-1)
        self.labels = torch.cat((labels_single, labels_bifurcating), dim=0)

    def __len__(self):
        """Return the total number of samples (both classes)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]

#
class PureFeaturesDataModule(pl.LightningDataModule):
    """Lightning data module wrapping :class:`PureFeaturesDataset`.

    Args:
      single_tensor (torch.Tensor): class-0 samples used for train/validation;
      bifurcating_tensor (torch.Tensor): class-1 samples used for
        train/validation;
      batch_size (int): batch size used by every dataloader;
      test_single_tensor (torch.Tensor, optional): held-out class-0 samples
        for the ``test``/``predict`` stages;
      test_bifurcating_tensor (torch.Tensor, optional): held-out class-1
        samples for the ``test``/``predict`` stages.

    Attributes:
      data_train (Subset): 90% random split of the training dataset
        (created by ``setup``);
      data_val (Subset): remaining 10% of the training dataset
        (created by ``setup``);
      data_test (PureFeaturesDataset): held-out dataset, only available when
        both test tensors were supplied (created by ``setup``).
    """

    def __init__(self,
                 single_tensor,
                 bifurcating_tensor,
                 batch_size=32,
                 test_single_tensor=None,
                 test_bifurcating_tensor=None,
                 ):

        super(PureFeaturesDataModule, self).__init__()
        self.single_tensor = single_tensor
        self.bifurcating_tensor = bifurcating_tensor
        self.batch_size = batch_size
        self.test_single_tensor = test_single_tensor
        self.test_bifurcating_tensor = test_bifurcating_tensor

    def _has_test_tensors(self):
        """Return True when both held-out tensors were given to the constructor."""
        return (self.test_single_tensor is not None
                and self.test_bifurcating_tensor is not None)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` (or ``None``) builds the train/validation split;
        ``'test'`` and ``'predict'`` build the held-out dataset. With
        ``stage=None`` the held-out dataset is also built when test tensors
        are available.

        Raises:
            RuntimeError: if ``stage`` is ``'test'`` or ``'predict'`` and no
                held-out tensors were supplied. (The previous implementation
                failed here as well, because it referenced attributes this
                class never defines.)
        """
        if stage == 'fit' or stage is None:
            dataset = PureFeaturesDataset(self.single_tensor, self.bifurcating_tensor)
            total_len = len(dataset)
            val_len = int(0.1 * total_len)
            # random_split needs the lengths to sum to the dataset size.
            self.data_train, self.data_val = random_split(
                dataset, [total_len - val_len, val_len])

        wants_test = stage in ('test', 'predict')
        if wants_test and not self._has_test_tensors():
            raise RuntimeError(
                "PureFeaturesDataModule was created without test_single_tensor / "
                "test_bifurcating_tensor, so no dataset exists for stage "
                f"'{stage}'.")
        if wants_test or (stage is None and self._has_test_tensors()):
            self.data_test = PureFeaturesDataset(
                self.test_single_tensor, self.test_bifurcating_tensor)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=0)

    def val_dataloader(self):
        """Sequential loader over the validation split."""
        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self):
        """Sequential loader over the held-out dataset."""
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=0)

    def predict_dataloader(self):
        """Sequential loader over the held-out dataset (same data as test)."""
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=0)

## Layers

In [None]:

## Embedding

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    A ``(1, seq_len, input_dim)`` table is precomputed once and stored as the
    non-trainable buffer ``embedding``: even feature columns hold sines and
    odd feature columns hold cosines of the position-dependent angles.

    Args:
        input_dim (int): size of the last (feature) dimension of the inputs.
        seq_len (int): maximum sequence length the table covers.
    """

    def __init__(self, input_dim, seq_len=50):
        super(PositionalEmbedding, self).__init__()

        embedd = torch.zeros(seq_len, input_dim)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, input_dim, 2).float() * (-math.log(10000.0) / input_dim))
        angles = position * div_term
        embedd[:, 0::2] = torch.sin(angles)
        # For odd input_dim there is one cosine column fewer than sine
        # columns, so trim the angles to fit (the untrimmed assignment raised
        # a shape error for odd dimensions).
        embedd[:, 1::2] = torch.cos(angles[:, : input_dim // 2])
        embedd = embedd.unsqueeze(0)
        self.register_buffer('embedding', embedd)

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x (torch.Tensor): tensor whose dimension 1 is the sequence axis
                and whose last dimension equals ``input_dim``; the sequence
                may be shorter than ``seq_len``.

        Returns:
            torch.Tensor: ``x`` plus the first ``x.size(1)`` rows of the table.
        """
        # Slice to the actual sequence length so sequences shorter than
        # seq_len work; longer sequences still fail with a shape error.
        out = x + self.embedding[:, : x.size(1)]
        return out

# Smoke check: build the module with a 16-dimensional feature size
# (the notebook displays the resulting module repr as the cell output).
PositionalEmbedding(16)

PositionalEmbedding()

In [None]:
class MultiHead(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into ``num_heads`` chunks of size
    ``emb_dim // num_heads``. A single ``head_dim x head_dim`` projection per
    role (query, key, value) is applied to every chunk, i.e. the Q/K/V weights
    are shared across heads. The per-head results are concatenated and passed
    through a final ``emb_dim -> emb_dim`` linear layer.

    Args:
        emb_dim (int): size of the input/output embedding.
        num_heads (int): number of attention heads; must divide ``emb_dim``.
    """

    def __init__(self, emb_dim, num_heads):

        super(MultiHead, self).__init__()
        assert emb_dim % num_heads == 0, "latent dimension must be divisible by the number of heads"

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = self.emb_dim // self.num_heads
        # Projection matrices into the Q/K/V spaces: each one is
        # head_dim x head_dim (e.g. emb_dim=64, num_heads=4 -> 16x16).
        self.q_matrix = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.k_matrix = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.v_matrix = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Output layer: takes the concatenation of the heads' outputs and
        # projects it back to the dimension of the input.
        self.output_projection = nn.Linear(self.num_heads * self.head_dim, self.emb_dim)

        # _reset_parameters() is intentionally not called here, so the layers
        # keep PyTorch's default initialisation; call it explicitly to opt in.

    def _reset_parameters(self):
        """Re-initialise all linear layers with Xavier-uniform weights and zero biases."""
        layers = (self.q_matrix, self.k_matrix, self.v_matrix, self.output_projection)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
            # The Q/K/V projections are created without a bias.
            if layer.bias is not None:
                layer.bias.data.fill_(0)

    def forward(self, k, q, v, mask=None):
        """Compute multi-head attention.

        Note the argument order: keys first, then queries, then values.

        Args:
            k (torch.Tensor): keys, ``(batch, kv_len, emb_dim)``.
            q (torch.Tensor): queries, ``(batch, q_len, emb_dim)``.
            v (torch.Tensor): values, ``(batch, kv_len, emb_dim)``.
            mask (torch.Tensor, optional): broadcastable to
                ``(batch, num_heads, q_len, kv_len)``; positions where the
                mask equals 0 are excluded from attention (e.g. a causal
                decoder mask or padding tokens).

        Returns:
            torch.Tensor: attended values, ``(batch, q_len, emb_dim)``.
        """
        batch_size = q.size(0)
        # Queries may have a different length than keys/values (cross-attention).
        q_len, kv_len = q.size(1), k.size(1)

        # Split the embedding into heads: (batch, len, num_heads, head_dim).
        k = k.reshape(batch_size, kv_len, self.num_heads, self.head_dim)
        q = q.reshape(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.reshape(batch_size, kv_len, self.num_heads, self.head_dim)

        # Project each head; the shape is unchanged.
        K = self.k_matrix(k)
        Q = self.q_matrix(q)
        V = self.v_matrix(v)

        # Move the head axis forward for the batched matmul:
        # (batch, num_heads, len, head_dim).
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        K_transpose = K.transpose(-1, -2)  # (batch, num_heads, head_dim, kv_len)

        # score: (batch, num_heads, q_len, kv_len)
        score = torch.matmul(Q, K_transpose)
        if mask is not None:
            # A very large negative score becomes ~0 probability after softmax.
            score = score.masked_fill(mask == 0, float('-1e20'))

        attention_score = F.softmax(score / math.sqrt(self.head_dim), dim=-1)
        attention = torch.matmul(attention_score, V)
        # reshape() copies when the transposed tensor is not contiguous, so no
        # explicit .contiguous() call is needed before merging the heads.
        concat_multihead = attention.transpose(1, 2)
        concat_multihead = concat_multihead.reshape(
            batch_size,
            q_len,
            self.num_heads * self.head_dim)
        output = self.output_projection(concat_multihead)
        return output

# Smoke check: build the block with a 64-dim embedding split over 4 heads
# (the notebook displays the resulting module repr as the cell output).
MultiHead(emb_dim=64, num_heads=4)


MultiHead(
  (q_matrix): Linear(in_features=16, out_features=16, bias=False)
  (k_matrix): Linear(in_features=16, out_features=16, bias=False)
  (v_matrix): Linear(in_features=16, out_features=16, bias=False)
  (output_projection): Linear(in_features=64, out_features=64, bias=True)
)

### Utils

In [None]:

class WarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Cosine learning-rate schedule with a linear warm-up phase.

    Each base learning rate is multiplied by
    ``0.5 * (1 + cos(pi * step / max_number_iters))``; while ``step`` is
    below ``warmup`` the factor is additionally scaled by ``step / warmup``.

    Args:
        optimizer: optimizer whose learning rates are scheduled.
        warmup: step threshold below which the warm-up scaling applies.
        max_number_iters: denominator of the cosine term.
    """

    def __init__(self, optimizer, warmup, max_number_iters=-1):
        # Both attributes must be set before the base constructor runs,
        # since get_lr_factor reads them.
        self.max_number_iters = max_number_iters
        self.warmup = warmup
        super().__init__(optimizer)

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        scale = self.get_lr_factor(epoch=self.last_epoch)
        scheduled = []
        for base_lr in self.base_lrs:
            scheduled.append(base_lr * scale)
        return scheduled

    def get_lr_factor(self, epoch):
        """Return the multiplicative factor applied to the base rates at ``epoch``."""
        cosine = 0.5 * (1 + np.cos(np.pi * epoch / self.max_number_iters))
        if epoch < self.warmup:
            return cosine * (epoch / self.warmup)
        return cosine


## Model

In [None]:
class Transformer(nn.Module):
    """Single attention block followed by a token-wise MLP and mean pooling.

    Pipeline: linear projection -> positional embedding -> multi-head
    self-attention -> residual + LayerNorm -> MLP (+ ReLU) -> mean over the
    sequence axis -> one logit per sample.

    Args:
        input_dim (int): size of the last dimension of the input.
        projected_dim (int): embedding size used inside the attention block.
        num_heads (int): number of attention heads.
        dropout (float): dropout probability inside the MLP.
    """

    def __init__(self, input_dim, projected_dim, num_heads=4, dropout=0.5):
        super().__init__()
        half_dim = projected_dim // 2

        # Encoder part.
        self.project = nn.Linear(input_dim, projected_dim)
        self.embedding = PositionalEmbedding(projected_dim)
        self.multi_head_block = MultiHead(projected_dim, num_heads)
        self.layer_norm1 = nn.LayerNorm(projected_dim)

        # Token-wise classifier head.
        mlp_layers = [
            nn.Linear(projected_dim, half_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.LayerNorm(half_dim),
            nn.Linear(half_dim, 64),
        ]
        self.fully_connected = nn.Sequential(*mlp_layers)
        self.logit_layer = nn.Linear(64, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return one unnormalised logit per sample.

        Args:
            x: ``[batch, seq_len, input_dim]``; a 4-D input has ``squeeze(2)``
               applied first.

        Returns:
            Tensor whose last dimension has size 1 (the logit).
        """
        tokens = x.squeeze(2) if x.ndim == 4 else x

        embedded = self.embedding(self.project(tokens))
        # Self-attention: the same tensor is used as key, query and value.
        attended = self.multi_head_block(embedded, embedded, embedded)
        # Residual connection followed by layer normalisation.
        normed = self.layer_norm1(attended + embedded)

        hidden = self.activation(self.fully_connected(normed))
        # Mean pooling over dimension 1 (the sequence axis).
        pooled = hidden.mean(dim=1)
        return self.logit_layer(pooled)



In [None]:

class TransformerArchitecture(pl.LightningModule):
    """Lightning wrapper that trains :class:`Transformer` as a binary classifier.

    The network outputs one logit per sample; a sigmoid turns it into a
    probability that is scored with binary cross-entropy and binary accuracy.

    Args:
        input_dim (int): feature size of the raw inputs.
        projected_dim (int): embedding size used inside the transformer.
    """

    def __init__(self, input_dim=3, projected_dim=128):
        super(TransformerArchitecture, self).__init__()

        # training parameters:
        self.lr = 1e-3
        # NOTE(review): WarmupScheduler compares this value with the step
        # counter (`epoch < warmup`), so 0.01 only affects step 0 (whose
        # factor becomes 0). Confirm whether a number of warm-up steps was
        # intended here.
        self.warmup = 0.01
        self.max_iters = 1000
        self.net = Transformer(
            input_dim=input_dim,
            projected_dim=projected_dim,
            num_heads=4,
            dropout=0.5)
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCELoss()
        self.accuracy_meter = torchmetrics.classification.Accuracy(
            task = "binary",
            num_classes = 2
         )

    def forward(self, x, mask=None):
        """Return the logits of the wrapped network (``mask`` is accepted but unused)."""
        return self.net(x)

    @torch.no_grad()
    def get_attention_maps(self, x, mask=None, add_positional_encoding=True):
        """
        Extract the attention matrices of the transformer for a single batch.

        Mirrors the first steps of ``Transformer.forward`` (projection,
        optional positional embedding) and then recomputes the softmax
        attention weights of the multi-head block.

        Args:
            x: input batch, same layout as for the forward pass.
            mask: optional mask; positions where it equals 0 are excluded.
            add_positional_encoding (bool): whether to add the positional
                embedding before computing attention.

        Returns:
            list[torch.Tensor]: one entry per attention block (a single one
            for this model), each of shape
            ``(batch, num_heads, seq_len, seq_len)``.
        """
        net = self.net
        if x.ndim == 4:
            x = x.squeeze(2)
        x = net.project(x)
        if add_positional_encoding:
            x = net.embedding(x)

        block = net.multi_head_block
        batch_size, seq_len = x.size(0), x.size(1)
        heads = x.reshape(batch_size, seq_len, block.num_heads, block.head_dim)
        queries = block.q_matrix(heads).transpose(1, 2)
        keys = block.k_matrix(heads).transpose(1, 2)

        score = torch.matmul(queries, keys.transpose(-1, -2))
        if mask is not None:
            score = score.masked_fill(mask == 0, float('-1e20'))
        attention = torch.softmax(score / block.head_dim ** 0.5, dim=-1)
        return [attention]

    def configure_optimizers(self):
        """Adam optimizer with a per-step warm-up/cosine learning-rate schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        # Apply lr scheduler per step
        lr_scheduler = WarmupScheduler(optimizer,
                                        warmup=self.warmup,
                                        max_number_iters=self.max_iters)

        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def _shared_step(self, batch, stage):
        """Run one forward pass, log ``{stage}_loss`` / ``{stage}_acc`` and return the loss."""
        inputs, labels = batch
        # Use the module's own device instead of a notebook-level global, so
        # the step also works when the module lives on a different device.
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        outputs = self.net(inputs).squeeze(-1)

        probabilities = self.sigmoid(outputs)
        targets = labels.float()
        loss = self.criterion(probabilities, targets)
        accuracy = self.accuracy_meter(probabilities, targets)

        values = {f"{stage}_loss": loss, f"{stage}_acc": accuracy}
        self.log_dict(values, prog_bar = True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy for one batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy for one batch."""
        return self._shared_step(batch, "test")


## Baseline Evaluation

In [None]:

# Build the data module from the two feature tensors.
# NOTE(review): single_tensor / bifurcating_tensor are not defined in the
# visible cells — presumably created by an earlier cell; confirm.
pure_features_datamodule = PureFeaturesDataModule(
    batch_size = 125,
    single_tensor = single_tensor,
    bifurcating_tensor = bifurcating_tensor
)
# Small baseline model: 3 input features projected to an 8-dim embedding.
transformer_model = TransformerArchitecture(input_dim=3,
                                            projected_dim=8).to(device)
# Create the train/validation split by hand so a train loader can be taken.
# NOTE(review): Lightning's Trainer.fit is documented to call setup() on the
# datamodule itself, which would redo the random split — confirm this manual
# call (and the dataloader_train variable, unused below) is still needed.
pure_features_datamodule.setup(None)
dataloader_train = pure_features_datamodule.train_dataloader()
# Train for a single epoch.
pl.Trainer(max_epochs=1).fit(transformer_model, pure_features_datamodule)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type           | Params | Mode 
----------------------------------------------------------
0 | net            | Transformer    | 561    | train
1 | sigmoid        | Sigmoid        | 0      | train
2 | criterion      | BCELoss        | 0      | train
3 | accuracy_meter | BinaryAccuracy | 0      | train
----------------------------------------------------------
561       Trainable params
0         Non-trainable params
561       Total params
0.002     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (29) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
