# Run this notebook in Google Colab (it mounts Google Drive and trains on a CUDA GPU)!

In [None]:
! pip install learn2learn

In [None]:
import learn2learn as l2l

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler

In [None]:
from google.colab import drive
# Mount Google Drive so the zipped, pre-normalized data can be copied into the working directory.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the pre-normalized data archive from Drive and extract it into the working directory.
# `-o` overwrites existing files without prompting, so re-running this cell does not block on stdin.
! cp /content/drive/MyDrive/gin_depmap_transfer_learning/normalized_no_leakage.zip .
! unzip -o normalized_no_leakage.zip

Archive:  normalized_no_leakage.zip
  inflating: depmap_crispr_zscore.csv  
  inflating: depmap_expression_lfc_zscore.csv  
  inflating: hap1_expression_lfc.csv  
  inflating: hap1_crispr.csv         


In [None]:
norm_dir = ""  # the archive was unzipped into the current working directory


def read_normalized(filename):
    """Read one normalized matrix; the first CSV column holds the row labels."""
    return pd.read_csv(norm_dir + filename, index_col=0)


hap1_expression_lfc = read_normalized("hap1_expression_lfc.csv")
hap1_crispr = read_normalized("hap1_crispr.csv")
depmap_expression_lfc_zscore = read_normalized("depmap_expression_lfc_zscore.csv")
depmap_crispr_zscore = read_normalized("depmap_crispr_zscore.csv")

tuple(
    frame.shape
    for frame in (hap1_expression_lfc, hap1_crispr, depmap_expression_lfc_zscore, depmap_crispr_zscore)
)

((60, 16372), (60, 16432), (1021, 16372), (1021, 16432))

In [None]:
# Fetch the project code (utils, MLP). The clone is skipped when the directory already exists,
# and %cd uses an absolute path, so re-running this cell is safe.
# NOTE(review): assumes the Colab working directory /content (where the CSVs were unzipped).
! [ -d /content/GI_transfer_learning ] || git clone https://github.com/danielchang2002/GI_transfer_learning /content/GI_transfer_learning
%cd /content/GI_transfer_learning/src
# NOTE(review): star import — later cells use `torch` and `nn` without importing them,
# so they presumably come from here; prefer explicit imports once utils' exports are confirmed.
from utils import *
from mlp import MLP

In [None]:
# HAP1: expression log-fold-changes (inputs) and CRISPR scores (labels).
data, labels = hap1_expression_lfc, hap1_crispr
# DepMap: z-scored expression LFC (inputs) and CRISPR (labels); these feed task sampling below.
data2, labels2 = depmap_expression_lfc_zscore, depmap_crispr_zscore

In [None]:
# OncotreeLineage of every DepMap cell line, in the same row order as data2.
tissue_info = pd.read_csv("Model.csv", index_col=0).loc[data2.index, "OncotreeLineage"]
# Number of cell lines per lineage (used to pick tissues with enough samples).
tissue_counts = tissue_info.value_counts()

In [None]:
len(tissue_counts)  # number of distinct lineages

27

In [None]:
min_tissue_size = 16


def get_task(batch_size=16, device="cuda", expression=None, crispr=None, lineage=None):
    """Sample one meta-learning task from a single randomly chosen tissue.

    A lineage with at least ``max(min_tissue_size, batch_size)`` cell lines is
    picked at random, ``batch_size`` of its cell lines are sampled without
    replacement, and the first half becomes the support (adaptation) set and the
    second half the query (evaluation) set.

    Args:
        batch_size: number of cell lines drawn per task (split in half).
        device: torch device the returned tensors are moved to.
        expression: input matrix (rows = cell lines); defaults to ``data2``.
        crispr: label matrix with the same row labels; defaults to ``labels2``.
        lineage: per-cell-line tissue label, indexed like ``expression``;
            defaults to ``tissue_info``.

    Returns:
        ``(support_x, support_y, query_x, query_y)`` tensors built with
        ``torch.Tensor`` (default float dtype) on ``device``.

    Raises:
        ValueError: if no tissue has enough cell lines for ``batch_size``.
    """
    if expression is None:
        expression = data2
    if crispr is None:
        crispr = labels2
    if lineage is None:
        lineage = tissue_info

    # A tissue must hold enough cell lines to draw a full batch without replacement.
    counts = lineage.value_counts()
    eligible = counts[counts >= max(min_tissue_size, batch_size)]
    if eligible.empty:
        raise ValueError(
            f"no tissue has at least {max(min_tissue_size, batch_size)} cell lines"
        )
    tissue = eligible.sample().index[0]

    in_tissue = lineage == tissue
    tissue_expression = expression[in_tissue]
    tissue_crispr = crispr[in_tissue]

    # Sample the cell lines ONCE and index both matrices with the same labels, so
    # each expression row stays paired with the CRISPR row of the same cell line.
    # (Sampling the two frames independently would shuffle inputs against labels.)
    rows = tissue_expression.sample(batch_size).index
    tissue_expression = tissue_expression.loc[rows]
    tissue_crispr = tissue_crispr.loc[rows]

    half = batch_size // 2

    def to_tensor(frame):
        return torch.Tensor(frame.values).to(device=device)

    return (
        to_tensor(tissue_expression.iloc[:half]),
        to_tensor(tissue_crispr.iloc[:half]),
        to_tensor(tissue_expression.iloc[half:]),
        to_tensor(tissue_crispr.iloc[half:]),
    )

In [None]:
min_tissue_size = 16


def get_task2():
    """Split DepMap cell lines into two disjoint groups by tissue.

    All lineages in ``tissue_counts`` (no size filter) are shuffled and halved.
    Cell lines from the first half form the "train" group and the rest the
    "test" group, so no tissue appears in both.
    NOTE(review): not called anywhere in this notebook.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` tensors on the "cuda" device.
    """
    lineages = list(tissue_counts.index)
    np.random.shuffle(lineages)

    cut = len(lineages) // 2
    first_lineages, second_lineages = lineages[:cut], lineages[cut:]

    first_mask = tissue_info.isin(first_lineages)
    second_mask = tissue_info.isin(second_lineages)

    frames = (
        data2[first_mask],
        labels2[first_mask],
        data2[second_mask],
        labels2[second_mask],
    )
    return tuple(torch.Tensor(frame.values).to(device="cuda") for frame in frames)

In [None]:
# Show the attached GPU and its memory use; later cells place the model and tensors on "cuda".
! nvidia-smi

Wed May  8 18:51:04 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0              28W /  70W |  10317MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import learn2learn as l2l

In [None]:
import torch
import torch.nn as nn  # explicit, instead of relying on `from utils import *`
import torch.optim as optim

# Reproducibility: pandas `.sample()` in the task samplers draws from NumPy's
# global RNG, and the model's weight initialisation draws from torch's RNG.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

TIMESTEPS = 1000     # outer (meta) steps; the training cell also sets this
TASKS_PER_STEP = 10  # tasks averaged into each meta-update
fine_tune_steps = 1  # inner-loop adaptation steps per task

# Output width comes from labels2 (DepMap CRISPR), the labels the model is meta-trained on.
model = MLP(data2.shape[1], labels2.shape[1]).to(device="cuda")
# Inner-loop (adaptation) lr 1e-2; first_order=False backpropagates through the adaptation step.
maml = l2l.algorithms.MAML(model, lr=1e-2, first_order=False)
opt = optim.Adam(maml.parameters(), lr=1e-3)  # outer-loop (meta) optimiser
loss_func = nn.MSELoss()

l2_weight1 = 0     # weight-norm penalty in the inner (support) loss — currently disabled
l2_weight2 = 0.01  # weight-norm penalty in the outer (query) loss

In [None]:
TIMESTEPS = 1000


def _rowwise_pearson(pred, true):
    """Pearson correlation between corresponding rows of two 2-D arrays.

    Same value as ``np.corrcoef(pred[k], true[k])[0, 1]`` for every row ``k``
    (computed in float64), but vectorised. Zero-variance rows yield NaN instead
    of raising, with the divide/invalid warnings suppressed.
    """
    pred = np.asarray(pred, dtype=np.float64)
    true = np.asarray(true, dtype=np.float64)
    pred_c = pred - pred.mean(axis=1, keepdims=True)
    true_c = true - true.mean(axis=1, keepdims=True)
    denom = np.sqrt((pred_c ** 2).sum(axis=1) * (true_c ** 2).sum(axis=1))
    with np.errstate(divide="ignore", invalid="ignore"):
        return (pred_c * true_c).sum(axis=1) / denom


history = []  # one (step, meta_loss, mean_query_corr) record per meta-step

for step in range(TIMESTEPS):
    step_loss = 0.0
    query_preds, query_targets = [], []

    for _task in range(TASKS_PER_STEP):
        support_x, support_y, query_x, query_y = get_task()

        # Adaptation: differentiable copy of the meta-model for this task.
        learner = maml.clone()

        # Inner loop: adapt to the task's support set.
        # Note: `.norm()` is the (non-squared) Frobenius norm of each weight matrix.
        for _ in range(fine_tune_steps):
            support_pred = learner(support_x)
            inner_loss = loss_func(support_pred, support_y) + l2_weight1 * (
                learner.fc1.weight.norm() + learner.fc2.weight.norm()
            )
            learner.adapt(inner_loss)

        # Evaluate the adapted learner on the query set; this loss drives the meta-update.
        query_pred = learner(query_x)
        adapt_loss = loss_func(query_pred, query_y) + l2_weight2 * (
            learner.fc1.weight.norm() + learner.fc2.weight.norm()
        )
        step_loss += adapt_loss

        # Detach before moving to CPU so the metric copies stay out of the autograd graph.
        query_preds.append(query_pred.detach().cpu())
        query_targets.append(query_y.detach().cpu())

    # Meta-learning step: gradient flows back through the adaptation steps.
    step_loss = step_loss / TASKS_PER_STEP
    opt.zero_grad()
    step_loss.backward()
    opt.step()

    # Mean per-cell-line correlation on the query sets; nanmean skips constant predictions.
    corrs = _rowwise_pearson(torch.vstack(query_preds).numpy(), torch.vstack(query_targets).numpy())
    corr = float(np.nanmean(corrs))

    history.append((step, step_loss.item(), corr))
    print(step, step_loss.item(), corr)

0 1.9211090803146362 -0.0005295735185442602 
1 2.951321601867676 -0.00422061782429232 
2 2.583693265914917 0.001244518840555361 
3 2.5528512001037598 -0.002799777724811843 
4 2.297614574432373 0.0030313017238701905 
5 2.550607442855835 0.0029332634139082477 
6 2.455620527267456 0.00012753865121573415 
7 2.3943257331848145 -0.0006501665718912051 
8 2.2364814281463623 0.009494920454239391 
9 2.738168239593506 0.0077271743305125885 
10 2.819582939147949 0.024862143982614007 
11 2.1507322788238525 0.005892407433799029 
12 2.2377662658691406 0.008014785888014119 
13 2.210594415664673 0.014796200703217327 
14 2.4311721324920654 0.013651878298487333 
15 2.1014697551727295 0.007878442223719837 
16 2.2003300189971924 0.010716305283650335 
17 2.1228034496307373 0.00792943970673653 
18 2.2817699909210205 0.014242771155133536 
19 2.1324334144592285 0.014645538012466352 
20 2.292640209197998 0.0192976423996562 
21 2.109809160232544 0.025654917942361844 
22 2.1012017726898193 0.02603342069035571 
23

KeyboardInterrupt: 

In [None]:
# Persist the weights of `model`, the module that MAML was constructed around.
checkpoint_path = "depmap_MLP_MAML.pt"
weights = model.state_dict()
torch.save(weights, checkpoint_path)