In [1]:
import time
t0start = time.time()
from fastai.collab import *
from fastai.tabular.all import *



## Set Random Seeds

In [2]:
def random_seed(seed_value, use_cuda):
    np.random.seed(seed_value) # for numpy random
    torch.manual_seed(seed_value) # for pytorch
    random.seed(seed_value) # for python random
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
seed_value=42
random_seed(seed_value, use_cuda=True)

## Read Data

Here I read the training and test data and melt it to yield a ```DataFrame``` with three categorical features (```cell_type```, ```sm_name```, and ```gene```) and one target (```value```).

In [3]:
%%time
fn = '/kaggle/input/open-problems-single-cell-perturbations/de_train.parquet'
df_de_train = pd.read_parquet(fn)
print(df_de_train.shape)
df_de_train

(614, 18216)
CPU times: user 1.95 s, sys: 411 ms, total: 2.37 s
Wall time: 2.38 s


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.104720,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.884380,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.704780,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.213550,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.224700,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C@@]3(Cl)[C@@H](O)C[C@]2(C)[C@@]1(OC(=O)c1ccco1)C(=O)CCl,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
609,T regulatory cells,Atorvastatin,LSM-5771,CC(C)c1c(C(=O)Nc2ccccc2)c(-c2ccccc2)c(-c2ccc(F)cc2)n1CC[C@@H](O)C[C@@H](O)CC(=O)O,False,-0.014372,-0.122464,-0.456366,-0.147894,-0.545382,...,-0.549987,-2.200925,0.359806,1.073983,0.356939,-0.029603,-0.528817,0.105138,0.491015,-0.979951
610,NK cells,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23)nc1N,False,-0.455549,0.188181,0.595734,-0.100299,0.786192,...,-1.236905,0.003854,-0.197569,-0.175307,0.101391,1.028394,0.034144,-0.231642,1.023994,-0.064760
611,T cells CD4+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23)nc1N,False,0.338168,-0.109079,0.270182,-0.436586,-0.069476,...,0.077579,-1.101637,0.457201,0.535184,-0.198404,-0.005004,0.552810,-0.209077,0.389751,-0.337082
612,T cells CD8+,Riociguat,LSM-45758,COC(=O)N(C)c1c(N)nc(-c2nn(Cc3ccccc3F)c3ncccc23)nc1N,False,0.101138,-0.409724,-0.606292,-0.071300,-0.001789,...,0.005951,-0.893093,-1.003029,-0.080367,-0.076604,0.024849,0.012862,-0.029684,0.005506,-1.733112


In [4]:
train_df_melted = df_de_train.melt(id_vars=['cell_type', 'sm_name'], value_vars=df_de_train.iloc[:,5:].columns, var_name='gene', value_name='value')
train_df_melted

Unnamed: 0,cell_type,sm_name,gene,value
0,NK cells,Clotrimazole,A1BG,0.104720
1,T cells CD4+,Clotrimazole,A1BG,0.915953
2,T cells CD8+,Clotrimazole,A1BG,-0.387721
3,T regulatory cells,Clotrimazole,A1BG,0.232893
4,NK cells,Mometasone Furoate,A1BG,4.290652
...,...,...,...,...
11181549,T regulatory cells,Atorvastatin,ZZEF1,-0.979951
11181550,NK cells,Riociguat,ZZEF1,-0.064760
11181551,T cells CD4+,Riociguat,ZZEF1,-0.337082
11181552,T cells CD8+,Riociguat,ZZEF1,-1.733112


In [5]:
fn = '/kaggle/input/open-problems-single-cell-perturbations/id_map.csv'
df_id_map = pd.read_csv(fn, index_col=0)
cols_to_add = df_de_train.iloc[:,5:].columns
df_zeros = pd.DataFrame(0.0, columns=cols_to_add, index=df_id_map.index)
test_df = pd.concat([df_id_map, df_zeros], axis=1)
test_df

Unnamed: 0_level_0,cell_type,sm_name,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,B cells,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,B cells,ABT-199 (GDC-0199),0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,B cells,ABT737,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,B cells,AMD-070 (hydrochloride),0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,B cells,AT 7867,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,Myeloid cells,Vandetanib,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
251,Myeloid cells,Vanoxerine,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
252,Myeloid cells,Vardenafil,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
253,Myeloid cells,Vorinostat,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
test_df_melted = test_df.melt(id_vars=['cell_type', 'sm_name'], value_vars=test_df.iloc[:,2:].columns, var_name='gene', value_name='value')
test_df_melted

Unnamed: 0,cell_type,sm_name,gene,value
0,B cells,5-(9-Isopropyl-8-methyl-2-morpholino-9H-purin-6-yl)pyrimidin-2-amine,A1BG,0.0
1,B cells,ABT-199 (GDC-0199),A1BG,0.0
2,B cells,ABT737,A1BG,0.0
3,B cells,AMD-070 (hydrochloride),A1BG,0.0
4,B cells,AT 7867,A1BG,0.0
...,...,...,...,...
4643800,Myeloid cells,Vandetanib,ZZEF1,0.0
4643801,Myeloid cells,Vanoxerine,ZZEF1,0.0
4643802,Myeloid cells,Vardenafil,ZZEF1,0.0
4643803,Myeloid cells,Vorinostat,ZZEF1,0.0


## Define Functions

The function ```get_denoised_data``` takes the original training data, as well as the melted ```DataFrame``` and returns a melted ```DataFrame``` that has been denoised using PCA and ```n_comp``` components. It also randomly selects and marks a fraction of ```0.2``` of the training set as validation data and replaces that data with the original non-denoised data.

In [7]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def get_denoised_data(data, data_melted, n_comp=35):
    Y = data.iloc[:,5:]

    # Standardize the data
    scaler = StandardScaler()
    Y_std = scaler.fit_transform(Y)

    reducer = PCA(n_components=n_comp, random_state=seed_value)

    Yr_std = reducer.fit_transform(Y_std)
    Yr_inv_trans_sdt = reducer.inverse_transform(Yr_std)
    Yr_inv_trans_denorm = scaler.inverse_transform(Yr_inv_trans_sdt)
    df_red_inv_trans = pd.DataFrame(Yr_inv_trans_denorm, columns = data.columns[5:])
    df_de_train_denoised = pd.concat([data.iloc[:, :2], df_red_inv_trans], axis=1)
    train_denoised_melted = df_de_train_denoised.melt(id_vars=['cell_type', 'sm_name'], value_vars=df_de_train_denoised.iloc[:,2:].columns, var_name='gene', value_name='value')
    train_denoised_melted['is_valid'] = False
    valid_indices = data_melted.sample(frac=0.2, random_state=seed_value).index.sort_values().tolist()
    train_denoised_melted.loc[valid_indices, 'value'] = data_melted.loc[valid_indices, 'value']
    train_denoised_melted.loc[valid_indices, 'is_valid'] = True
    return train_denoised_melted

This splitter function is fed into fast.ai's ```TabularPandas``` to get the test-valid-split.

In [8]:
def splitter(df):
    train = df.index[~df['is_valid']].tolist()
    valid = df.index[df['is_valid']].tolist()
    return L(train), L(valid)

The function ```get_dls``` takes the original training data and the melted ```DataFrame```, applies ```get_denoised_data```, calls fast.ai's ```TabularPandas``` to get the training data into the correct format and returns the batched data loaders to feed to the model.

In [9]:
def get_dls(data, data_melted, n_comp=35, bs=4096):
    train_denoised_melted = get_denoised_data(data, data_melted, n_comp=n_comp)
    cats = train_denoised_melted.iloc[:, :3].columns.to_list()
    splits = splitter(train_denoised_melted)
    to = TabularPandas(train_denoised_melted, procs=[Categorify], cat_names=cats, y_names='value', splits=splits)
    dls = to.dataloaders(bs)
    return dls

In [10]:
def rmse(preds, targs):
    return ((targs-preds)**2).mean().sqrt().item()

## Train Tabular Learner with PCA Denoising Using 10 Components

In [11]:
%%time
dls = get_dls(df_de_train, train_df_melted, n_comp=10, bs=4096)
dls.valid.show()

Unnamed: 0,cell_type,sm_name,gene,value
0,NK cells,Clotrimazole,A1BG,0.10472
2,T cells CD8+,Clotrimazole,A1BG,-0.387721
11,T cells CD4+,Idelalisib,A1BG,0.206471
12,T cells CD8+,Idelalisib,A1BG,0.046959
13,T regulatory cells,Idelalisib,A1BG,0.210456
17,T regulatory cells,Vandetanib,A1BG,-0.077971
22,NK cells,Ceritinib,A1BG,-1.122906
23,T cells CD4+,Ceritinib,A1BG,-0.14371
26,NK cells,Lamivudine,A1BG,0.407019
28,T cells CD8+,Lamivudine,A1BG,0.150796


CPU times: user 14.5 s, sys: 3.33 s, total: 17.8 s
Wall time: 15.1 s


In [12]:
dls.train.show()

Unnamed: 0,cell_type,sm_name,gene,value
1,T cells CD4+,Clotrimazole,A1BG,0.008733
3,T regulatory cells,Clotrimazole,A1BG,0.436567
4,NK cells,Mometasone Furoate,A1BG,0.599439
5,T cells CD4+,Mometasone Furoate,A1BG,0.204235
6,T cells CD8+,Mometasone Furoate,A1BG,0.162474
7,T regulatory cells,Mometasone Furoate,A1BG,0.824737
8,B cells,Idelalisib,A1BG,0.234715
9,Myeloid cells,Idelalisib,A1BG,0.354405
10,NK cells,Idelalisib,A1BG,0.007681
14,NK cells,Vandetanib,A1BG,-0.037544


Here I define an ```emb_szs``` dictionary to change the default dimension of ```gene``` embeddings from ```389``` to ```600```:

In [13]:
emb_szs = {'cell_type': 5, 'sm_name': 26, 'gene': 600}
get_emb_sz(dls.train_ds)

[(7, 5), (147, 26), (18212, 389)]

The next code block is to find the min and max values of the target, because the ```tabular_learner``` tends to learn better if the target range (```y_range```) is provided.

In [14]:
y = dls.train.y
y_min = np.ceil(y.min()).astype(int)
y_max = np.ceil(y.max()).astype(int)
y_min, y_max

(-189, 186)

In [15]:
learn = tabular_learner(dls, y_range=(y_min, y_max), emb_szs=emb_szs, layers=[1000, 500, 250], n_out=1, loss_func=F.mse_loss)
learn.fit_one_cycle(20, 3e-4)

epoch,train_loss,valid_loss,time
0,15.014743,15.060467,01:06
1,4.043351,4.640964,01:05
2,1.480463,2.147437,01:05
3,0.814333,1.504443,01:05
4,0.473158,1.572463,01:06
5,0.263745,1.098528,01:07
6,0.191113,1.072918,01:07
7,0.144829,0.891821,01:06
8,0.118946,1.314085,01:06
9,0.095639,1.066761,01:07


In [16]:
preds, targs = learn.get_preds()
rmse(preds, targs)

0.8800662159919739

## Submission

In [17]:
test_dl = dls.test_dl(test_df_melted)
preds, _ = learn.get_preds(dl=test_dl)
preds

tensor([[ 0.3602],
        [ 0.1060],
        [ 0.4222],
        ...,
        [-0.1846],
        [-0.3003],
        [-0.2014]])

In the last step I reshape the predictions back into a ```255 x 18211``` tensor for submission:

In [18]:
to_submit = preds.view(18211, -1).t().numpy()
submit = pd.DataFrame(to_submit, columns=df_de_train.iloc[:,5:].columns)
submit.index.name = 'id'
submit

Unnamed: 0_level_0,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.360184,-0.070511,0.929810,0.873932,1.260040,0.187805,-0.146988,0.209686,-0.485550,1.086548,...,-0.250732,-0.043564,-0.119598,-0.017654,0.387466,0.153000,0.385727,0.539215,-0.138779,0.288895
1,0.105972,-0.033234,0.023071,0.232880,0.832367,0.504883,-0.186371,0.054001,-0.019928,0.112091,...,-0.072083,-0.048248,-0.254120,0.146942,-0.030045,-0.162277,0.043091,-0.009720,-0.090500,-0.119598
2,0.422165,-0.019440,0.278885,0.291199,1.682907,1.522385,-0.228607,0.170013,-0.101028,0.266876,...,0.026978,-0.040451,-0.129318,0.302444,0.144485,-0.006607,0.100922,0.107742,-0.109390,-0.030548
3,0.032257,-0.045059,0.091110,0.159332,0.393356,0.195847,-0.090057,0.047318,-0.059570,0.102859,...,-0.148331,-0.010101,-0.134262,0.024033,0.109055,-0.027756,0.145935,0.073471,-0.153458,-0.096085
4,0.186768,0.026779,0.468155,0.446732,0.955368,0.465897,-0.078964,0.146744,-0.116379,0.507172,...,-0.150177,0.042801,-0.118835,0.093002,0.270615,0.134918,0.230118,0.195496,-0.149109,-0.040588
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,0.209412,0.100159,-0.496902,-0.150070,0.833099,0.210938,0.086868,0.019852,0.081451,0.050964,...,0.245667,0.127625,-0.198608,0.154068,0.330109,0.272522,0.291382,0.062729,-0.062851,-0.148285
251,0.296371,0.089966,-0.596054,-0.039673,1.738449,1.159119,-0.045303,0.219452,-0.105072,0.254639,...,-0.070801,-0.015381,-0.345245,-0.000885,0.362686,0.136230,0.134766,-0.040207,-0.151855,-0.141327
252,0.146378,0.062943,-0.700897,-0.043854,1.378372,0.294189,-0.028824,0.089050,-0.130981,0.167038,...,0.002411,-0.129883,-0.404449,-0.061111,0.145493,-0.015976,-0.030701,0.001053,-0.043365,-0.184555
253,0.492279,0.757477,-2.275116,-0.019302,3.406982,2.113602,0.265747,0.464630,0.326874,0.361893,...,-0.039780,0.047638,-1.358688,0.094177,0.118530,0.083313,-0.625778,-0.059479,-0.058380,-0.300339


In [19]:
submit.to_csv('submission.csv')