In [2]:
import torch

batch_size = 32

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from torch.utils.data import Dataset
from torch import tensor
import polars as pl
import time

class Sites(Dataset):

    chrsm_lengths = {
        "1": 30427671,
        "2": 19698289,
        "3": 23459830,
        "4": 18585056,
        "5": 26975502,
    }

    def chrsms(self):
        return [1, 2, 3, 4] if self.mode == 'train' else [5]

    def __init__(self, mode='train', width = 256):
        self.mode = mode
        self.width = width
        self.features = pl.DataFrame()
        self.labels = pl.DataFrame()
        for chrsm in self.chrsms():
            df =  pl.read_parquet(f"./embeddings/chr_{chrsm}.parquet")
            self.features = self.features.vstack(df)

        for chrsm in self.chrsms():
            df =  pl.read_parquet(f"./labels/{chrsm}.parquet")
            self.labels = self.labels.vstack(df)
        self.labels = self.labels.drop(["sequence", "std_st"])


        self.features = self.features.to_numpy()
        self.labels = self.labels.to_numpy()
        

    def __len__(self):
        return len(self.features) - 2 * self.width
    
    def __getitem__(self, idx):
        start = idx
        end = idx + 2 * self.width
        f = self.features[start:end]
        l = self.labels[idx+self.width]

        f = torch.tensor(f, device=device, dtype=torch.float32).reshape(-1)
        l = torch.squeeze(torch.tensor(l, device=device, dtype=torch.float32))

        return f, l
    


        


In [4]:
from torch.utils.data import DataLoader

train_dataset = Sites(mode='train')
test_dataset = Sites(mode='test')


train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [5]:
from torch import nn
from torch.nn import functional as F

class MethylationMaster(nn.Module):

    def steady_state(alpha, beta):
      
        pi1 = lambda a, b:  (a * ((1.0 - a) ** 2  - (1.0 - b)**(2) - 1.0)) / ((a + b) * ((a + b - 1.0)**(2) - 2.0))
        pi2 = lambda a, b:  (4.0 * a * b * (a + b - 2.0)) / ((a + b) * ((a + b - 1.0)**(2) - 2.0));
    
        return pi1(alpha, beta) + 0.5 * pi2(alpha, beta)



    def __init__(self):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=5120, nhead=80, batch_first=True)
        # self.start = nn.Linear(5120, 512)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.linear = nn.Linear(5120, 2)
        self.model = nn.Sequential(self.encoder, self.linear)
        

    def forward(self, features, targets = None):
        logits =  self.model(features) # (B, (alpha, beta))
       
        if targets is None:
            return logits, None

        loss = F.mse_loss(logits, targets)

        return logits, loss
    
    # @torch.no_grad()
    # def generate(self, idx):
    #     logits = self.transformer(idx, idx) # (B * S, C)

    #     last = logits[-1]
    #     last_tar = targets[-1]

    #     probs = F.softmax(last, dim=-1) 
    #     # sample from the distribution
    #     guess = torch.multinomial(probs, num_samples=1) # (B, 1)


    #     print(f"Guess: {guess}, correct: {last_tar}")

    #     return guess == last_tar

In [6]:
conschti = MethylationMaster().to(device)

print(sum(p.numel() for p in conschti.parameters())/1e6, 'M parameters')
optimizer = torch.optim.Adam(conschti.parameters(), lr=1e-3)


503.519234 M parameters


In [8]:
epochs = 5
test_interval = 100
run_name = "encoder_only_755M"

test_dataloader = iter(test_dataloader)
num = len(train_dataloader)
print(num)

#reset log file
with open(f"log/{run_name}.txt", "w") as f:
    f.write("epoch,step,train_loss,test_loss\n")

for epoch in range(epochs):
    for i, (feature, label) in enumerate(train_dataloader): #batched
        optimizer.zero_grad()

        logits, loss = conschti(feature , label)
        loss.backward()
        optimizer.step()

        percent_done = (i / num) * 100
        percent_done = '{:.4f}'.format(percent_done)

        print(f"Step {i} ({percent_done}%), loss: {loss.item()}, alpha: {logits[0][0]}, beta: {logits[0][1]}, steady_state: {MethylationMaster.steady_state(logits[0][0], logits[0][1])}")

        if i % test_interval == 0:
            with torch.no_grad():
                    feature, label = next(test_dataloader)
                    logits, test_loss = conschti(feature , label)
                    print(f"Validation loss: {loss.item()}")

                    with open(f"log/{run_name}.txt", "a") as f:
                      f.write(f"{epoch},{i},{loss},{test_loss}\n")


2880323
Step 0 (0.0000%), loss: 0.03843260183930397, alpha: 0.27837392687797546, beta: 0.11482207477092743
Validation loss: 0.03843260183930397
Step 1 (0.0000%), loss: 0.018323194235563278, alpha: 0.15977507829666138, beta: 0.13840101659297943
Step 2 (0.0001%), loss: 0.022135496139526367, alpha: 0.23037387430667877, beta: 0.06178037077188492
Step 3 (0.0001%), loss: 0.016016550362110138, alpha: -0.08077594637870789, beta: 0.04442596808075905
Step 4 (0.0001%), loss: 0.014578926376998425, alpha: 0.07141273468732834, beta: 0.07230084389448166
Step 5 (0.0002%), loss: 0.02764589712023735, alpha: -0.04781973361968994, beta: -0.04489924758672714
Step 6 (0.0002%), loss: 0.02301700785756111, alpha: 0.12388560175895691, beta: 0.06368600577116013
Step 7 (0.0002%), loss: 0.016085147857666016, alpha: 0.27394604682922363, beta: 0.15755964815616608
Step 8 (0.0003%), loss: 0.04460945725440979, alpha: -0.1817365437746048, beta: -0.09079090505838394
Step 9 (0.0003%), loss: 0.018867190927267075, alpha: 0.