In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
sequence_length = 20

In [3]:
class PatientDataset(Dataset):
    def __init__(self, num_patients, sequence_length, num_features):
        self.num_patients = num_patients
        self.sequence_length = sequence_length
        self.num_features = num_features

    def __len__(self):
        return self.num_patients

    def __getitem__(self, idx):
        heart_rate = torch.randint(low = 40, high = 140, size = (self.sequence_length, 1))
        heart_condition = torch.randint(low = 0, high = 4, size = (self.sequence_length, 1))
        pulse_rate = torch.randint(low = 60, high = 120, size = (self.sequence_length, 1))
        oxygen_level = torch.randint(low = 80, high = 100, size = (self.sequence_length, 1))
        medical_factors = torch.randn(self.sequence_length, self.num_features - 4, dtype=torch.float32)

        sequence = torch.cat((heart_rate, heart_condition, pulse_rate, oxygen_level, medical_factors), dim=1)
        return sequence

pdata = PatientDataset(num_patients = 100, sequence_length = 20, num_features = 5)

In [4]:
len(pdata)

100

In [5]:
pdata[1]

tensor([[ 4.9000e+01,  0.0000e+00,  1.0300e+02,  8.9000e+01, -1.6359e+00],
        [ 5.4000e+01,  3.0000e+00,  6.8000e+01,  9.5000e+01, -1.4221e+00],
        [ 5.3000e+01,  3.0000e+00,  1.0300e+02,  8.7000e+01,  8.2934e-01],
        [ 9.3000e+01,  0.0000e+00,  1.0700e+02,  9.6000e+01, -1.5719e-02],
        [ 1.0300e+02,  3.0000e+00,  8.3000e+01,  9.9000e+01, -1.4787e+00],
        [ 9.5000e+01,  2.0000e+00,  6.9000e+01,  9.9000e+01,  2.8943e-01],
        [ 6.6000e+01,  1.0000e+00,  6.1000e+01,  8.4000e+01, -7.3562e-01],
        [ 1.0200e+02,  2.0000e+00,  9.7000e+01,  9.7000e+01,  5.3776e-02],
        [ 9.0000e+01,  2.0000e+00,  7.8000e+01,  9.8000e+01,  1.1203e+00],
        [ 9.0000e+01,  2.0000e+00,  8.1000e+01,  9.2000e+01,  4.9537e-01],
        [ 1.3400e+02,  0.0000e+00,  6.9000e+01,  9.0000e+01,  1.3899e+00],
        [ 6.7000e+01,  1.0000e+00,  9.7000e+01,  9.4000e+01, -9.0376e-01],
        [ 7.7000e+01,  3.0000e+00,  1.0400e+02,  8.6000e+01, -5.9632e-01],
        [ 4.8000e+01,  1.

In [6]:
batch_size = 16
embedding_dim = 5

In [7]:
class LearnablePositionalEncoding(nn.Module):
    def __init__(self, sequence_length, embedding_dim):
        super(LearnablePositionalEncoding, self).__init__()
        self.position_embeddings = nn.Parameter(torch.randn(sequence_length, embedding_dim))

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_embeddings = self.position_embeddings[position_ids]
        return input_ids + position_embeddings

In [8]:
from torch.optim import Adam

num_epochs = 5

In [9]:
data_loader = DataLoader(pdata, batch_size=batch_size, shuffle=True)

LearnPE = LearnablePositionalEncoding(sequence_length, embedding_dim)
criterion = nn.MSELoss()
optimizer = Adam(LearnPE.parameters(), lr=0.001)

# learn_pe_optimizer = torch.optim.Adam(LearnPE.parameters(), lr=0.001)
# actual_model_optimizer = torch.optim.Adam(actual_model.parameters(), lr=0.001)
# combined_params = list(LearnPE.parameters()) + list(actual_model.parameters())
# optimizer = torch.optim.Adam(combined_params, lr=0.001)

for epoch in range(num_epochs):
    for batch in data_loader:
        optimizer.zero_grad()

        input_ids = batch 
        encoded_data = LearnPE(input_ids)
        
        #output = actual_model(encoded_data)
        #loss = criterion(output, target)
        # loss.backward()
        # optimizer.step()

        print(encoded_data)
        # print(loss.item())

    print(f"Epoch {epoch + 1}/{num_epochs} completed")

tensor([[[ 44.5234,   3.6773, 113.9422,  96.9759,  -0.6214],
         [ 68.1214,   2.6005, 118.5694,  84.3657,  -0.6933],
         [118.9975,   0.5904,  95.7787,  80.5550,  -0.6045],
         ...,
         [ 52.0974,   1.5757,  73.8304,  88.7863,  -0.2643],
         [100.4841,  -0.7639, 112.7743,  84.1079,  -3.0372],
         [120.4638,   0.6510, 116.3238,  96.8653,  -1.6880]],

        [[ 46.5234,   2.6773, 111.9422,  84.9759,  -1.0367],
         [115.1214,   1.6005, 106.5694,  99.3657,   0.6642],
         [106.9975,   3.5904,  87.7787,  88.5550,  -0.3288],
         ...,
         [ 83.0975,   2.5757, 117.8304,  85.7863,   1.2913],
         [120.4841,   2.2361, 119.7743,  98.1079,  -2.2664],
         [ 43.4638,   2.6510,  60.3237,  88.8653,  -0.3202]],

        [[ 77.5234,   2.6773, 107.9422,  87.9759,  -0.9977],
         [108.1214,   1.6005,  79.5694,  87.3657,  -0.3648],
         [ 99.9975,   0.5904,  98.7787,  83.5550,  -1.2103],
         ...,
         [105.0975,   2.5757, 109.8304,