In [1]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 

import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim 
from datasets import load_dataset 

import os 
import h5py
from tqdm import tqdm
import pyarrow
import time 
import json 

import sys 
sys.path.append('../code')
import utils
import model 
import loss 
from custom_dataset import CustomDataset
import train


In [2]:
# Build the gene vocabulary pieces and hand them to the project tokenizer.
gene_dict, dataset_gene, dataset_gene_ids = utils.generate_gene_dic()
tokenizer = utils.tokenizer_v1(
    gene_dict=gene_dict,
    dataset_gene=dataset_gene,
    dataset_gene_ids=dataset_gene_ids,
)

# Vocabulary size before any special token is registered.
vocab_size = tokenizer.vocab_size
print(vocab_size)

# Register the special tokens in a fixed order, then show the updated
# vocabulary size followed by the gene_dict entry stored for each token.
for special_token in ('<cls>', '<pad>'):
    tokenizer.add_token(token=special_token)
print(tokenizer.vocab_size)
for special_token in ('<cls>', '<pad>'):
    print(tokenizer.gene_dict[special_token])

33524
33526
33524
33525


In [4]:

# Batch collater shared by both DataLoaders; padding uses the tokenizer's
# '<pad>' entry so it stays in sync with the vocabulary built above.
collate_fn = utils.collater(
    tokenizer=tokenizer,
    max_expression=100,
    mask_ratio=0.1,
    max_num=2000,
    rho=0.1,
    pad_idx=tokenizer.gene_dict['<pad>'],
)

In [5]:

# Load the two pretraining corpora (allen_2021_data and allen_2023_data) from
# local disk. An earlier variant of this cell loaded path='mus_brain' with
# cache_dir='huggingface_cache' instead.
dataset_1 = load_dataset(path = '/work/sunrui/pretrain_dataset/allen_2021_data',
                         cache_dir = '/work/sunrui/huggingface')
dataset_2 = load_dataset(path = '/work/sunrui/pretrain_dataset/allen_2023_data',
                         cache_dir = '/work/sunrui/huggingface')

# Keep only the first 3000 rows of each 'train' split and hold out 5% of them
# for evaluation.
# NOTE(review): no seed is passed to train_test_split, so the split is
# presumably not reproducible across runs — consider passing seed=...
dataset_1 = dataset_1['train'].select(range(3000)).train_test_split(test_size = 0.05)
dataset_2 = dataset_2['train'].select(range(3000)).train_test_split(test_size = 0.05)

# Each corpus contributes its OWN held-out split. The first test split used to
# be read from dataset_2, which dropped dataset_1's held-out rows and put
# dataset_2's test rows into the evaluation set twice.
train_dataset_1, test_dataset_1 = dataset_1['train'], dataset_1['test']
train_dataset_2, test_dataset_2 = dataset_2['train'], dataset_2['test']

# Combine the per-corpus splits through the project's CustomDataset wrapper.
train_dataset = CustomDataset([train_dataset_1, train_dataset_2])
test_dataset = CustomDataset([test_dataset_1, test_dataset_2])


Resolving data files:   0%|          | 0/117 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/117 [00:00<?, ?it/s]

  table = cls._concat_blocks(blocks, axis=0)


Resolving data files:   0%|          | 0/419 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/412 [00:00<?, ?it/s]

In [6]:
# Hyperparameters for model.sc_pretrain. They are passed POSITIONALLY in the
# call at the bottom, so that call's argument order must not be changed.

# Embedding sizes and padding indices.
count_embedding_num = 104
# NOTE(review): equals count_embedding_num - 1 — presumably the last count id
# is reserved for padding; confirm against the collater / model code.
count_padding_idx = 103
gene_embedding_num = tokenizer.vocab_size
gene_padding_idx = tokenizer.gene_dict['<pad>']

# Encoder settings. The names mirror torch.nn.TransformerEncoderLayer /
# TransformerEncoder arguments — presumably forwarded to them; verify in model.py.
d_model = 256
n_head = 8
dim_ffn = d_model * 4
dropout = 0.1
layer_norm_eps = 1e-5
batch_first = True
norm_first = False
num_layers = 8
norm = None

# Hidden width handed to the model as its last positional argument.
num_hiddens = 256

my_model = model.sc_pretrain(
    count_embedding_num,
    gene_embedding_num,
    d_model,
    gene_padding_idx,
    count_padding_idx,
    n_head,
    dim_ffn,
    dropout,
    layer_norm_eps,
    batch_first,
    norm_first,
    num_layers,
    norm,
    num_hiddens,
)



In [7]:

# Create the DataLoader instances. Both share the batch size and collater;
# only the training loader shuffles, the test loader keeps dataset order.
batch_size = 4
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [8]:

# Loss object from the project `loss` module; it is handed to the training loop.
pretrain_loss = loss.pretrain_loss()

# Run the project training loop and rebind my_model to whatever it returns.
# An earlier experiment used SGD instead:
#   optim.SGD(my_model.parameters(), lr=1e-4, momentum=0.9)
# The optimizer argument previously lacked its trailing comma, which made this
# whole cell a SyntaxError.
my_model = train.train_multi_epoch(
    my_model,
    train_loader,
    test_loader,
    pretrain_loss,
    optimizer = optim.Adam(my_model.parameters(), lr = 5e-5, weight_decay = 0.01),
    device = 'cuda',
    gradient_accumulation_steps = 8,
    save_steps = 100,
    save_dir = 'test_1',
    epochs = 5,
)



  batch_data['counts_0'] = torch.tensor(batch_data['counts_0'], dtype = torch.int)


Step 1, Loss: 6.3907, Exp_loss : 5.1032, Clip_loss ; 1.2875
Step 2, Loss: 6.3051, Exp_loss : 5.0146, Clip_loss ; 1.2906
Step 3, Loss: 6.1143, Exp_loss : 4.8274, Clip_loss ; 1.2869
Step 4, Loss: 5.8445, Exp_loss : 4.5789, Clip_loss ; 1.2656
Step 5, Loss: 5.5032, Exp_loss : 4.2857, Clip_loss ; 1.2175
Step 6, Loss: 5.3384, Exp_loss : 4.0509, Clip_loss ; 1.2874
Step 7, Loss: 4.9877, Exp_loss : 3.7539, Clip_loss ; 1.2338
Step 8, Loss: 4.7551, Exp_loss : 3.5503, Clip_loss ; 1.2048
Step 9, Loss: 4.5518, Exp_loss : 3.4357, Clip_loss ; 1.1160
Step 10, Loss: 4.3712, Exp_loss : 3.2551, Clip_loss ; 1.1162
Step 11, Loss: 4.4349, Exp_loss : 3.2084, Clip_loss ; 1.2265
Step 12, Loss: 4.1086, Exp_loss : 3.1109, Clip_loss ; 0.9977
Step 13, Loss: 3.9702, Exp_loss : 2.9074, Clip_loss ; 1.0629
Step 14, Loss: 3.8383, Exp_loss : 2.9057, Clip_loss ; 0.9325
Step 15, Loss: 3.7225, Exp_loss : 2.9009, Clip_loss ; 0.8216
Step 16, Loss: 3.7319, Exp_loss : 2.8022, Clip_loss ; 0.9297
Step 17, Loss: 3.5798, Exp_loss :

Step 135, Loss: 2.9966, Exp_loss : 2.5398, Clip_loss ; 0.4568
Step 136, Loss: 2.7800, Exp_loss : 2.3455, Clip_loss ; 0.4345
Step 137, Loss: 2.7462, Exp_loss : 2.3348, Clip_loss ; 0.4115
Step 138, Loss: 2.7929, Exp_loss : 2.4969, Clip_loss ; 0.2961
Step 139, Loss: 2.7619, Exp_loss : 2.3305, Clip_loss ; 0.4314
Step 140, Loss: 2.8022, Exp_loss : 2.3888, Clip_loss ; 0.4134
Step 141, Loss: 2.7013, Exp_loss : 2.4037, Clip_loss ; 0.2976
Step 142, Loss: 2.8123, Exp_loss : 2.3506, Clip_loss ; 0.4617
Step 143, Loss: 2.7401, Exp_loss : 2.4081, Clip_loss ; 0.3320
Step 144, Loss: 2.7661, Exp_loss : 2.3560, Clip_loss ; 0.4101
Step 145, Loss: 3.0306, Exp_loss : 2.4396, Clip_loss ; 0.5910
Step 146, Loss: 2.7468, Exp_loss : 2.3585, Clip_loss ; 0.3883
Step 147, Loss: 2.7162, Exp_loss : 2.3087, Clip_loss ; 0.4075
Step 148, Loss: 2.6127, Exp_loss : 2.3188, Clip_loss ; 0.2939
Step 149, Loss: 2.7541, Exp_loss : 2.3980, Clip_loss ; 0.3561
Step 150, Loss: 2.8425, Exp_loss : 2.4658, Clip_loss ; 0.3767
Step 151

  return torch._transformer_encoder_layer_fwd(


avg total loss:22.4946, avg exp loss:19.6103, avg_clip_loss:2.8843
Step 1, Loss: 2.5076, Exp_loss : 2.2283, Clip_loss ; 0.2794
Step 2, Loss: 2.8697, Exp_loss : 2.4382, Clip_loss ; 0.4315
Step 3, Loss: 2.5960, Exp_loss : 2.2989, Clip_loss ; 0.2971
Step 4, Loss: 2.6980, Exp_loss : 2.3872, Clip_loss ; 0.3108
Step 5, Loss: 2.8205, Exp_loss : 2.4475, Clip_loss ; 0.3730
Step 6, Loss: 2.8958, Exp_loss : 2.4384, Clip_loss ; 0.4574
Step 7, Loss: 2.6974, Exp_loss : 2.4088, Clip_loss ; 0.2887
Step 8, Loss: 2.7315, Exp_loss : 2.3962, Clip_loss ; 0.3353
Step 9, Loss: 2.7010, Exp_loss : 2.3229, Clip_loss ; 0.3782
Step 10, Loss: 2.6358, Exp_loss : 2.4014, Clip_loss ; 0.2343
Step 11, Loss: 2.8076, Exp_loss : 2.4119, Clip_loss ; 0.3957
Step 12, Loss: 2.6564, Exp_loss : 2.3106, Clip_loss ; 0.3459
Step 13, Loss: 2.7213, Exp_loss : 2.4565, Clip_loss ; 0.2647
Step 14, Loss: 2.7843, Exp_loss : 2.3770, Clip_loss ; 0.4073
Step 15, Loss: 2.4318, Exp_loss : 2.1357, Clip_loss ; 0.2961
Step 16, Loss: 2.7364, Exp_

Step 132, Loss: 2.6451, Exp_loss : 2.3108, Clip_loss ; 0.3343
Step 133, Loss: 2.9070, Exp_loss : 2.4650, Clip_loss ; 0.4420
Step 134, Loss: 2.6249, Exp_loss : 2.3457, Clip_loss ; 0.2793
Step 135, Loss: 2.6495, Exp_loss : 2.2659, Clip_loss ; 0.3836
Step 136, Loss: 2.8743, Exp_loss : 2.3653, Clip_loss ; 0.5090
Step 137, Loss: 2.5595, Exp_loss : 2.3013, Clip_loss ; 0.2582
Step 138, Loss: 2.8243, Exp_loss : 2.4221, Clip_loss ; 0.4022
Step 139, Loss: 2.6439, Exp_loss : 2.2616, Clip_loss ; 0.3823
Step 140, Loss: 2.5810, Exp_loss : 2.3032, Clip_loss ; 0.2778
Step 141, Loss: 2.7272, Exp_loss : 2.4363, Clip_loss ; 0.2909
Step 142, Loss: 2.7432, Exp_loss : 2.4025, Clip_loss ; 0.3407
Step 143, Loss: 2.4836, Exp_loss : 2.2723, Clip_loss ; 0.2113
Step 144, Loss: 2.6520, Exp_loss : 2.3988, Clip_loss ; 0.2532
Step 145, Loss: 2.6399, Exp_loss : 2.3365, Clip_loss ; 0.3033
Step 146, Loss: 2.6556, Exp_loss : 2.3228, Clip_loss ; 0.3329
Step 147, Loss: 2.6253, Exp_loss : 2.4036, Clip_loss ; 0.2217
Step 148

Step 82, Loss: 2.8446, Exp_loss : 2.4698, Clip_loss ; 0.3749
Step 83, Loss: 2.7495, Exp_loss : 2.3750, Clip_loss ; 0.3744
Step 84, Loss: 2.8736, Exp_loss : 2.3866, Clip_loss ; 0.4870
Step 85, Loss: 2.5561, Exp_loss : 2.3512, Clip_loss ; 0.2049
Step 86, Loss: 2.5063, Exp_loss : 2.2500, Clip_loss ; 0.2563
Step 87, Loss: 2.7271, Exp_loss : 2.3529, Clip_loss ; 0.3742
Step 88, Loss: 2.6173, Exp_loss : 2.3297, Clip_loss ; 0.2876
Step 89, Loss: 2.5170, Exp_loss : 2.2338, Clip_loss ; 0.2833
Step 90, Loss: 2.6362, Exp_loss : 2.2349, Clip_loss ; 0.4013
Step 91, Loss: 2.6478, Exp_loss : 2.3656, Clip_loss ; 0.2821
Step 92, Loss: 2.7435, Exp_loss : 2.3840, Clip_loss ; 0.3594
Step 93, Loss: 2.5412, Exp_loss : 2.3594, Clip_loss ; 0.1818
Step 94, Loss: 2.6562, Exp_loss : 2.3446, Clip_loss ; 0.3116
Step 95, Loss: 2.7277, Exp_loss : 2.4171, Clip_loss ; 0.3106
Step 96, Loss: 2.7912, Exp_loss : 2.4508, Clip_loss ; 0.3405
Step 97, Loss: 2.6467, Exp_loss : 2.3822, Clip_loss ; 0.2645
Step 98, Loss: 2.6962, E

KeyboardInterrupt: 

In [8]:
# Print index and name of every visible CUDA device; prints nothing when CUDA
# is unavailable (the device count is only queried when it is available).
gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
for i in range(gpu_count):
    print(f"Device {i}: {torch.cuda.get_device_name(i)}")

Device 0: NVIDIA RTX A6000
Device 1: NVIDIA RTX A6000
