<a href="https://colab.research.google.com/github/foxtrotmike/musings/blob/main/vtransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd

def generate_biomedical_data(num_patients, max_time_points, measurement_dim):
    """Create a random, irregularly sampled toy dataset in long format.

    Each patient (numbered from 1) gets between 5 and ``max_time_points``
    observations at distinct integer times drawn from 1..99, sorted in
    ascending order. Every observation carries ``measurement_dim`` values
    drawn uniformly from [0, 100).

    Args:
        num_patients: Number of patients to simulate.
        max_time_points: Inclusive upper bound on observations per patient.
        measurement_dim: Number of measurement columns per observation.

    Returns:
        A DataFrame with one row per (patient, time) pair and columns
        ``patient_id``, ``time``, ``measurement_1`` .. ``measurement_<dim>``.
    """
    header = ['patient_id', 'time']
    header += [f'measurement_{k}' for k in range(1, measurement_dim + 1)]
    candidate_times = range(1, 100)

    rows = []
    for pid in range(1, num_patients + 1):
        # How many visits this patient has
        n_obs = np.random.randint(5, max_time_points + 1)

        # Distinct visit times, sorted so the series is chronological
        picked = np.random.choice(candidate_times, size=n_obs, replace=False)
        obs_times = np.sort(picked)

        # One measurement vector per visit, scaled to [0, 100)
        readings = 100 * np.random.rand(n_obs, measurement_dim)

        rows.extend([pid, t, *vals] for t, vals in zip(obs_times, readings))

    return pd.DataFrame(rows, columns=header)

# Build a small synthetic cohort to experiment with
num_patients = 10
max_time_points = 20
measurement_dim = 3  # e.g. three different kinds of biomedical readings
df = generate_biomedical_data(
    num_patients=num_patients,
    max_time_points=max_time_points,
    measurement_dim=measurement_dim,
)

# Preview the dataset (up to the first 50 rows)
print(df.head(50))


import torch
import torch.nn as nn
import torch.optim as optim

class TimeSeriesTransformer(nn.Module):
    """Encoder-decoder transformer followed by a linear projection.

    The wrapped ``nn.Transformer`` is built with ``batch_first=True``, so
    sequences are passed with the batch dimension first. The model width
    equals ``feature_size``; no separate input embedding is applied.

    Args:
        feature_size: Width of the model (``d_model``) and of the output.
        num_layers: Number of encoder layers and of decoder layers.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability passed to the transformer.
    """

    def __init__(self, feature_size, num_layers, num_heads, dropout_rate=0.1):
        super().__init__()
        transformer_config = {
            'd_model': feature_size,
            'nhead': num_heads,
            'num_encoder_layers': num_layers,
            'num_decoder_layers': num_layers,
            'dropout': dropout_rate,
            'batch_first': True,
        }
        self.transformer = nn.Transformer(**transformer_config)
        # Final projection back to the feature dimension
        self.output_layer = nn.Linear(feature_size, feature_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Run the transformer and project its decoder output.

        Args:
            src: Source (encoder input) sequence.
            tgt: Target (decoder input) sequence.
            src_mask: Padding mask for ``src``; it is forwarded both as
                ``src_key_padding_mask`` and as ``memory_key_padding_mask``.
                May be ``None``.
            tgt_mask: Attention mask for ``tgt`` (e.g. a causal mask).

        Returns:
            The decoder output passed through the linear output layer.
        """
        decoded = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_mask,
            memory_key_padding_mask=src_mask,
        )
        projected = self.output_layer(decoded)
        return projected

def generate_square_subsequent_mask(sz, device=None):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
    position ``j``) and ``-inf`` when ``j > i`` (future positions are hidden).

    Args:
        sz: Sequence length; the mask is ``sz`` by ``sz``.
        device: Optional device on which to create the mask. When ``None``
            the mask is created on torch's current default device. The
            function no longer reads a module-level ``src`` tensor, so it can
            be called before any data exists or from another module.

    Returns:
        A float tensor with zeros on and below the diagonal and ``-inf``
        strictly above it.
    """
    # triu with diagonal=1 keeps only the strictly-upper -inf entries and
    # zeroes everything on and below the main diagonal.
    return torch.triu(
        torch.full((sz, sz), float('-inf'), device=device),
        diagonal=1,
    )

#def train(transformer, data_loader, loss_fn, optimizer, epochs=10):


# Example use case: `data_loader` stands in for a PyTorch DataLoader that yields batches of data.
# This is a minimal, illustrative example. In practice you will need to define a Dataset class,
# handle batching, and possibly move the model and data to a GPU.
# Model hyper-parameters
feature_size = 3  # matches the number of measurement dimensions
num_layers = 3
num_heads = 3

# Model, optimiser and objective
transformer = TimeSeriesTransformer(feature_size, num_layers, num_heads)
optimizer = optim.Adam(transformer.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Stand-in for a real data loader: a single random batch shaped
# (batch_size, sequence_length, feature_size); tgt has the same shape as src.
src = torch.rand(10, 20, feature_size)
tgt = torch.rand(10, 20, feature_size)
data_loader = [(src, tgt)]

# Training loop
transformer.train()
for epoch in range(100):
    for src, tgt in data_loader:
        optimizer.zero_grad()

        # The source sequence is passed without a padding mask
        src_mask = None

        # Shift the target by one step: the decoder sees every element but
        # the last and is trained to predict every element but the first.
        tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]

        # Causal mask sized to the decoder input
        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))

        output = transformer(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask)
        loss = loss_fn(output, tgt_output)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')

    patient_id  time  measurement_1  measurement_2  measurement_3
0            1     6      28.214488      25.393803       1.788428
1            1    21       2.575569      57.769994      46.926204
2            1    52      43.186728      63.535987      11.080344
3            1    59      78.745628      45.087964      35.086413
4            1    63      21.947808      96.519428      11.337316
5            1    67      33.094361      72.345273      62.391860
6            2     7      77.289804      56.127689      75.777841
7            2    28       0.466625      92.337506       6.906353
8            2    29      27.838122      65.696205      39.141420
9            2    33       7.986671      88.166573      86.680852
10           2    50      25.640740      56.371510       8.314938
11           2    55      93.037869      91.857094      33.617259
12           2    97      38.875356      98.153366      51.192967
13           3     5      81.554330      25.355395      99.984621
14        



Epoch 1, Loss: 0.7607957720756531
Epoch 2, Loss: 0.17592217028141022
Epoch 3, Loss: 0.16767065227031708
Epoch 4, Loss: 0.15558499097824097
Epoch 5, Loss: 0.1413862258195877
Epoch 6, Loss: 0.1255502998828888
Epoch 7, Loss: 0.11852536350488663
Epoch 8, Loss: 0.11318311095237732
Epoch 9, Loss: 0.10218221694231033
Epoch 10, Loss: 0.09718501567840576
Epoch 11, Loss: 0.09024440497159958
Epoch 12, Loss: 0.08566103130578995
Epoch 13, Loss: 0.08285683393478394
Epoch 14, Loss: 0.08484569936990738
Epoch 15, Loss: 0.08462855219841003
Epoch 16, Loss: 0.08179233223199844
Epoch 17, Loss: 0.08445437997579575
Epoch 18, Loss: 0.08609229326248169
Epoch 19, Loss: 0.08584374189376831
Epoch 20, Loss: 0.08617661148309708
Epoch 21, Loss: 0.08712138235569
Epoch 22, Loss: 0.08769042044878006
Epoch 23, Loss: 0.08646060526371002
Epoch 24, Loss: 0.08669541776180267
Epoch 25, Loss: 0.08639346808195114
Epoch 26, Loss: 0.08558256179094315
Epoch 27, Loss: 0.08509645611047745
Epoch 28, Loss: 0.08375903964042664
Epoch 2