In [1]:
import torch
import torch.nn as nn

In [2]:
class NonEqui_mv_linear(nn.Module):

    def __init__(self, in_channels, out_channels):

        super(NonEqui_mv_linear, self).__init__()
        self.MV_dim = 16
        
        self.w = torch.randn(out_channels, in_channels, self.MV_dim)
        self.bias = torch.randn(out_channels, self.MV_dim)

        self.w = nn.Parameter(self.w)
        self.bias = nn.Parameter(self.bias)

        self.register_parameter(name="weight", param=self.w)
        self.register_parameter(name="bias", param=self.bias)
    def forward(self, x):

        # expected x shape :
        #[batch, seq_len, channels, mv_dim]
        out = torch.einsum('b s c m, C c m -> b s C m', x, self.w)
        
        # add batch by broadcasting
        out = out + self.bias
        return out

In [3]:
x = torch.randn(32, 50, 4, 16)
model = NonEqui_mv_linear(4, 8)
out = model(x)

In [4]:
out.shape

torch.Size([32, 50, 8, 16])

In [7]:
x = torch.randn(32, 50, 4, 16)
linear = NonEqui_mv_linear(in_channels=4, out_channels=4//2, batch=x.shape[0], seq_len=x.shape[1])
params = [4, 4//2, x.shape[0], x.shape[1]]
Q = NonEqui_mv_linear(*params)
K = NonEqui_mv_linear(*params)
V = NonEqui_mv_linear(*params)

softmax = nn.functional.softmax

# the multi_head proj ??

In [None]:
class MultiHeadBlock(nn.Module):

    def __init__(self, emb_dim, n_head):
        super(MultiHeadBlock, self).__init__()
        
        if emb_dim % n_head != 0:
            raise ValueError("emb dim not compatible with n_head")
        self.input_size = emb_dim
        self.n_head = n_head
        self.head_dim = self.input_size // n_head

        self.Q_layer = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.K_layer = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.V_layer = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.softmax = nn.functional.softmax
        self.multihead_projection = nn.Linear(emb_dim ,self.input_size)

    def forward(self, input):
        """
        input shape = [batch, seq_len, emb_dim]
        """
        batch, seq_len = input.shape[0], input.shape[1]
        # original x shape = [batch, seq_len, channels, 16]
        #x = input.reshape(batch_size, seq_len, self.n_head, self.head_dim)
        
        x = x.permute(0, 2, 1, 3) #  [batch, seq_len, channels, emb_dim]
        q, k, v = x.clone(), x.clone(), x.clone()

        # out : batch, seq_len, channels, emb_dim
        q = self.Q_layer(q)
        k = self.K_layer(k)
        v = self.V_layer(v)

        # out batch channels seq_len, emb
        k_T = k.permute(0, 2, 1, 3)
        # old   # [batch, n_head, head_dim, seq_len]
        #        k_T = k.permute(0, 1, 3, 2) 

        # self attention
        # --------------
        attention_w = self.softmax(torch.matmul(q, k_T)/self.head_dim**0.5, dim=-1)
        attention_score = torch.matmul(attention_w, v) # [batch, n_head, seq_len, head_dim]

        # come back to [batch, seq_len, n_head, emb_size]
        attention_score = attention_score.permute(0, 2, 1, 3)

        # squeeze to input shape
        attention_score = attention_score.reshape(batch, seq_len, self.input_size)

        # project to original input size
        attention_out = self.multihead_projection(attention_score)

        # combine attention and input = skipp connection
        return attention_out + input


class FullyConnected(nn.Module):
    def __init__(self, emb_dim, ffn_emb):
        super(FullyConnected, self).__init__()

        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, ffn_emb),
            nn.ReLU(), 
            nn.Linear(ffn_emb, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, x):
        return x + self.ffn(x)
    
# ---------------
# test
x = torch.randn(32, 4, 16)
emb_size = 64
transformer_block = nn.Sequential(
    nn.Linear(16, emb_size), 
    MultiHeadBlock(emb_size,n_head=4),              # contains skipp connection
    nn.LayerNorm(emb_size), 
    FullyConnected(emb_size, ffn_emb=emb_size*4),   # fully connected block + skipp connection 
    nn.LayerNorm(emb_size)
    
)
transformer_block(x).shape


torch.Size([32, 4, 64])

In [None]:
x = torch.randn(32, 4, 16)
emb_size = 64

class FullyConnected(nn.Module):
    def __init__(self, emb_dim, ffn_emb):
        super(FullyConnected, self).__init__()

        self.ffn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, ffn_emb),
            nn.ReLU(), 
            nn.Linear(ffn_emb, emb_dim),
        )

    def forward(self, x):
        return x + self.ffn(x)
    

transformer_block = nn.Sequential(
    nn.Linear(16, emb_size),                      # project into latent dim
    MultiHeadBlock(emb_size,n_head=4),            # multihead + skipp connection
    FullyConnected(emb_size, ffn_emb=emb_size*4), # fully connected block + skipp connection 
    nn.LayerNorm(emb_size)
)
transformer_block(x)

class NonEquivarTransformer(nn.Module):

    def __init__(self, emb_size, n_block=1):
        
        super(NonEquivarTransformer, self).__init__()

        transformer_block = nn.Sequential(
            nn.Linear(16, emb_size),                      # project into latent dim
            MultiHeadBlock(emb_size,n_head=4),            # multihead + skipp connection
            FullyConnected(emb_size, ffn_emb=emb_size*4), # fully connected block + skipp connection 
            nn.LayerNorm(emb_size)
        )
        self.transformer = nn.Sequential(
            *[transformer_block for i in range(n_block)]  # -> *[] unpack the list 
        )
        self.classifier = nn.Sequential(
            nn.Linear(emb_size, 1), 
            # nn.Sigmoid() --> not used, just return logit : numerical stability in the loss
        )

    def forward(self, x):
        x = self.transformer(x)
        x = x.mean(dim=1) # mean pooling across the seq_len dim
        y = self.classifier(x)
        return y

classifier = NonEquivarTransformer(emb_size=64)


In [None]:
classifier

NonEquivarTransformer(
  (transformer): Sequential(
    (0): Sequential(
      (0): Linear(in_features=16, out_features=64, bias=True)
      (1): MultiHeadBlock(
        (Q_layer): Linear(in_features=64, out_features=64, bias=False)
        (K_layer): Linear(in_features=64, out_features=64, bias=False)
        (V_layer): Linear(in_features=64, out_features=64, bias=False)
        (softmax): Softmax(dim=-1)
        (multihead_projection): Linear(in_features=64, out_features=64, bias=True)
      )
      (2): FullyConnected(
        (ffn): Sequential(
          (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=64, out_features=256, bias=True)
          (2): ReLU()
          (3): Linear(in_features=256, out_features=64, bias=True)
        )
      )
      (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [None]:
criterion = nn.BCEWithLogitsLoss()
y = torch.randint(high=2, size=(y_hat.shape[0],)).reshape(y_hat.shape).to(torch.float32)
optimizer = torch.optim.Adam(classifier.parameters())

for i in range(100):
    y_hat = classifier(x)
    loss = criterion(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
y_hat

tensor([[ 6.7776],
        [-6.4134],
        [ 6.6611],
        [-6.5333],
        [ 6.7147],
        [-6.3710],
        [-6.3725],
        [-6.2817],
        [-6.4525],
        [-6.4585],
        [-6.7606],
        [ 6.7850],
        [-6.3850],
        [ 6.3662],
        [-6.3539],
        [-6.5630],
        [-6.7570],
        [-6.7158],
        [ 6.3337],
        [ 6.8090],
        [ 6.2347],
        [ 6.5377],
        [ 6.4298],
        [-6.4829],
        [ 6.5602],
        [-6.6577],
        [ 6.2722],
        [-6.7036],
        [ 6.3142],
        [ 6.4708],
        [-6.7307],
        [-6.6444]], grad_fn=<AddmmBackward0>)

In [None]:
nn.functional.sigmoid(y_hat)

tensor([[0.9989],
        [0.0016],
        [0.9987],
        [0.0015],
        [0.9988],
        [0.0017],
        [0.0017],
        [0.0019],
        [0.0016],
        [0.0016],
        [0.0012],
        [0.9989],
        [0.0017],
        [0.9983],
        [0.0017],
        [0.0014],
        [0.0012],
        [0.0012],
        [0.9982],
        [0.9989],
        [0.9980],
        [0.9986],
        [0.9984],
        [0.0015],
        [0.9986],
        [0.0013],
        [0.9981],
        [0.0012],
        [0.9982],
        [0.9985],
        [0.0012],
        [0.0013]], grad_fn=<SigmoidBackward0>)

In [None]:
y_hat

torch.Size([32, 1])

In [None]:
import pytorch_lightning as pl
from dataclasses import dataclass


@dataclass
class Params:

    batch_size: int = 32
    learning_rate: float = 0.001
    
    emb_size: int =     64 # projection dim for equilinear layer


class ClassifierLightning(pl.LightningModule):

    def __init__(self, params):
        super(ClassifierLightning, self).__init__()
        self.emb_size = params.emb_size
        self.lr = params.learning_rate
        self.batch_size = params.batch_size

        self.model = NonEquivarTransformer(emb_size=emb_size)
        self.criterion = nn.BCEWithLogitsLoss()
        

    def forward(self, x):
        self.model(x)

    def training_step(self, batch):
        x, y = batch
        y_hat = self(x)  # call the forward
        self.criterion(y_hat, y.view(-1, 1))
        self.log("train loss", loss)
        return loss

    def validation_step(self, batch):
        x, y = batch
        y_hat = self(x)  # call the forward
        self.criterion(y_hat, y.view(-1, 1))
        self.log("val loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)



In [None]:
params = Params()
light_model = ClassifierLightning(params)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

x_train = torch.rand(32, 4, 16)  # 100 samples
y_train = (torch.rand(32) > 0.5).float()  # Binary labels

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

x_val = torch.rand(32, 4, 16)  # Validation data
y_val = (torch.rand(32) > 0.5).float()  # Binary labels
val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=16)

    # Define trainer
trainer = pl.Trainer(max_epochs=10)

    # Fit the model
trainer.fit(light_model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /home/dema/Project/GAT/src/Transformer_standard/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                  | Params
----------------------------------------------------
0 | model     | NonEquivarTransformer | 50.9 K
1 | criterion | BCEWithLogitsLoss     | 0     
----------------------------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/dema/Project/GAT/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


TypeError: ClassifierLightning.validation_step() takes 1 positional argument but 2 were given