In [1]:
import pandas as pd
import numpy as np
import bcolz as bz
import torch
from torch import nn

In [2]:
from fastai.vision import *

In [3]:
SIZE = 456
SITE = 3 # Site: 1:site1, 2:site2, 3:site1 and 2

LR = 1e-5
BS = 64
EPOCHS = 100

In [4]:
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [5]:
train_bc = bz.open("/data/rcic/actv_train")

In [6]:
arr = train_bc[[2,5,7,8]]

In [7]:
train_bc.shape

(73024, 2048)

In [8]:
train_bc.flush()

In [9]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    

SEED = 0
seed_everything(SEED)


In [10]:
from pathlib import Path
print("Loading training dataframe")
DATA = Path("/mnt/disk4/cell/")

train_df = pd.read_csv(DATA/'train.csv')
train_df.head(10)

Loading training dataframe


Unnamed: 0,id_code,experiment,plate,well,sirna
0,HEPG2-01_1_B03,HEPG2-01,1,B03,513
1,HEPG2-01_1_B04,HEPG2-01,1,B04,840
2,HEPG2-01_1_B05,HEPG2-01,1,B05,1020
3,HEPG2-01_1_B06,HEPG2-01,1,B06,254
4,HEPG2-01_1_B07,HEPG2-01,1,B07,144
5,HEPG2-01_1_B08,HEPG2-01,1,B08,503
6,HEPG2-01_1_B09,HEPG2-01,1,B09,188
7,HEPG2-01_1_B10,HEPG2-01,1,B10,700
8,HEPG2-01_1_B11,HEPG2-01,1,B11,1100
9,HEPG2-01_1_B12,HEPG2-01,1,B12,611


In [11]:
groups = np.load("groups.npy")

In [12]:
def generate_df(train_df,sample_num=1):
    train_df['path'] = train_df['experiment'].str.cat(train_df['plate'].astype(str).str.cat(train_df['well'],sep='/'),sep='/Plate') + '_s'+str(sample_num) + '_w'
    train_df["pname"] = train_df.apply(lambda x:x["experiment"]+"-"+x["plate"].__str__(), axis=1)
#     train_df["grp"] = train_df.sirna.apply(lambda x:groups[x])
    train_df = train_df.drop(columns=['id_code','experiment','plate','well']).reindex(columns=['path','sirna',"pname"])
    
    return train_df

In [13]:
site1_train_df = generate_df(train_df)  
site2_train_df = generate_df(train_df, sample_num=2)

if SITE==1: # only site1
    proc_train_df = site1_train_df 
elif SITE==2 : # only site2
    proc_train_df = site2_train_df
elif SITE==3 :
    proc_train_df = pd.concat([site1_train_df,site2_train_df],axis=0 ).reset_index().drop("index",axis=1).head(73024)
#     proc_train_df.to_csv("train_with_bc.csv")

In [29]:
act_df = proc_train_df.reset_index().rename(columns = {"index":"actid"})

In [30]:
act_df["grp"]=act_df["sirna"].apply(lambda x:groups[x])

In [32]:
split_ = (np.random.rand(len(act_df))<0.965)

In [33]:
train_act_df = act_df[split_]
val_act_df = act_df[~split_]
len(train_act_df),len(val_act_df)

(70390, 2634)

In [34]:
import math
class actDs(Dataset):
    def __init__(self,df,ba,bs=BS):
        self.df = df.sample(frac = 1.).reset_index().drop("index",axis=1)
        self.ba = ba
        self.bs = bs
        
    def __len__(self):
        return math.ceil(float(len(self.df))/float(self.bs))
    
    def __getitem__(self,idx):
        df_ = self.df[idx*self.bs:(idx+1)*self.bs]
        act_arr = self.ba[list(df_.actid)]
        self.ba.flush()
        return act_arr,df_.sirna.values, df_.grp.values

In [35]:
train_ds = actDs(train_act_df,train_bc)
val_ds = actDs(val_act_df,train_bc)

In [18]:
train_dl = DataLoader(train_ds,batch_size=1, num_workers=4,)

In [36]:
from forgebox.ftorch.train import Trainer
from forgebox.ftorch.callbacks import stat

In [37]:
class aLearner(nn.Module):
    def __init__(self,):
        super().__init__()
        self.top = nn.Sequential(*[
            nn.Linear(2048,2048, bias=False),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048,1108, bias=False),
            nn.BatchNorm1d(1108),
            nn.Softmax(),])
        
    def forward(self,x):
        return self.top(x)        

In [38]:
CUDA = torch.cuda.is_available()

In [39]:
model = aLearner().train()
if CUDA:
    model.cuda()

In [40]:
from forgebox.ftorch.metrics import accuracy

In [41]:
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

In [42]:
crit = nn.CrossEntropyLoss()

In [43]:
t = Trainer(train_ds, val_dataset=val_ds,
            batch_size=1, 
            num_workers=4, 
            callbacks=[stat], val_callbacks=[stat])


In [44]:
@t.step_train
def action(batch):
    if batch.i ==0:
        model.train()
    opt.zero_grad()
    x,y1,y2 = batch.data
    x  =x[0].float(); y1 = y1[0]
    if CUDA:
        x = x.cuda()
        y1 = y1.cuda()
        y2 = y2.cuda()
        
    y1_ = model(x)
    loss = crit(y1_,y1)
    acc = accuracy(y1_,y1)
    loss.backward()
    opt.step()
    return {"loss":loss.item(),"acc":acc.item()}

@t.step_val
def val_action(batch):
    if batch.i ==0:
        model.eval()
#     opt.zero_grad()
    x,y1,y2 = batch.data
    x  =x[0].float(); y1 = y1[0]
    if CUDA:
        x = x.cuda()
        y1 = y1.cuda()
        y2 = y2.cuda()
        
    y1_ = model(x)
    loss = crit(y1_,y1)
    acc = accuracy(y1_,y1)
#     loss.backward()
#     opt.step()
    return {"loss":loss.item(),"acc":acc.item()}

In [45]:
t.train(50)

HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))

  input = module(input)





Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.949199,0.220131,0.0,549.5,0.018473
min,6.856333,0.0,0.0,0.0,0.0
max,7.010438,0.53125,0.0,1099.0,20.320372


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.732979,0.411979,0.0,20.5,0.017826
min,6.592898,0.1,0.0,0.0,0.0
max,6.996486,0.59375,0.0,41.0,0.748686


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.822542,0.511285,1.0,549.5,0.016149
min,6.730528,0.171875,1.0,0.0,0.0
max,6.982502,0.703125,1.0,1099.0,17.763589


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.601595,0.552976,1.0,20.5,0.018113
min,6.458587,0.1,1.0,0.0,0.0
max,6.986842,0.75,1.0,41.0,0.760735


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.728081,0.630264,2.0,549.5,0.019775
min,6.627197,0.21875,2.0,0.0,0.0
max,6.985244,0.8125,2.0,1099.0,21.75243


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.535044,0.618973,2.0,20.5,0.019467
min,6.381725,0.2,2.0,0.0,0.0
max,6.969682,0.78125,2.0,41.0,0.817617


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.655589,0.694378,3.0,549.5,0.016098
min,6.544094,0.296875,3.0,0.0,0.0
max,6.957313,0.890625,3.0,1099.0,17.708171


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.494779,0.66875,3.0,20.5,0.019815
min,6.330268,0.4,3.0,0.0,0.0
max,6.923688,0.859375,3.0,41.0,0.832248


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.59604,0.736003,4.0,549.5,0.019904
min,6.479646,0.4375,4.0,0.0,0.0
max,6.941741,0.9375,4.0,1099.0,21.894903


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.46181,0.691443,4.0,20.5,0.017575
min,6.284893,0.4,4.0,0.0,0.0
max,6.822679,0.859375,4.0,41.0,0.738162


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.545721,0.766948,5.0,549.5,0.016791
min,6.426653,0.421875,5.0,0.0,0.0
max,6.945586,0.9375,5.0,1099.0,18.470564


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.438757,0.709821,5.0,20.5,0.034152
min,6.265386,0.5,5.0,0.0,0.0
max,6.804122,0.875,5.0,41.0,1.434388


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.501842,0.78861,6.0,549.5,0.025535
min,6.347242,0.40625,6.0,0.0,0.0
max,6.932436,0.96875,6.0,1099.0,28.088964


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.422927,0.729167,6.0,20.5,0.015303
min,6.255398,0.5,6.0,0.0,0.0
max,6.790184,0.859375,6.0,41.0,0.64271


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.461634,0.805545,7.0,549.5,0.019447
min,6.331493,0.390625,7.0,0.0,0.0
max,6.949885,0.96875,7.0,1099.0,21.391354


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.410235,0.74256,7.0,20.5,0.024117
min,6.249838,0.5,7.0,0.0,0.0
max,6.728755,0.84375,7.0,41.0,1.012896


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.42645,0.819465,8.0,549.5,0.020581
min,6.302649,0.421875,8.0,0.0,0.0
max,6.949191,0.984375,8.0,1099.0,22.638669


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.396072,0.745759,8.0,20.5,0.021784
min,6.243227,0.4,8.0,0.0,0.0
max,6.754444,0.875,8.0,41.0,0.914909


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.393808,0.832607,9.0,549.5,0.017218
min,6.264248,0.484375,9.0,0.0,0.0
max,6.874777,0.984375,9.0,1099.0,18.939913


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.396832,0.756324,9.0,20.5,0.019838
min,6.246338,0.5,9.0,0.0,0.0
max,6.706263,0.859375,9.0,41.0,0.833202


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.364287,0.842965,10.0,549.5,0.020019
min,6.249118,0.5,10.0,0.0,0.0
max,6.920336,1.0,10.0,1099.0,22.020751


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.389602,0.764286,10.0,20.5,0.016121
min,6.238245,0.6,10.0,0.0,0.0
max,6.736875,0.875,10.0,41.0,0.677066


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.338079,0.852072,11.0,549.5,0.015975
min,6.219151,0.453125,11.0,0.0,0.0
max,6.905402,1.0,11.0,1099.0,17.57293


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.373588,0.772247,11.0,20.5,0.019898
min,6.226645,0.640625,11.0,0.0,0.0
max,6.66964,0.859375,11.0,41.0,0.83571


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.313415,0.860896,12.0,549.5,0.017372
min,6.200775,0.453125,12.0,0.0,0.0
max,6.878119,1.0,12.0,1099.0,19.109647


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.367992,0.775223,12.0,20.5,0.021385
min,6.222563,0.625,12.0,0.0,0.0
max,6.648908,0.875,12.0,41.0,0.89815


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.291066,0.868555,13.0,549.5,0.018476
min,6.178883,0.40625,13.0,0.0,0.0
max,6.88198,1.0,13.0,1099.0,20.323079


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.36519,0.777827,13.0,20.5,0.023444
min,6.229288,0.671875,13.0,0.0,0.0
max,6.661317,0.875,13.0,41.0,0.984649


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.270267,0.876436,14.0,549.5,0.020849
min,6.153095,0.515625,14.0,0.0,0.0
max,6.817224,1.0,14.0,1099.0,22.934204


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.352068,0.787872,14.0,20.5,0.020972
min,6.212526,0.671875,14.0,0.0,0.0
max,6.519842,0.875,14.0,41.0,0.880804


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.251928,0.883214,15.0,549.5,0.017443
min,6.139864,0.515625,15.0,0.0,0.0
max,6.851214,1.0,15.0,1099.0,19.186797


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.346967,0.784152,15.0,20.5,0.020051
min,6.212479,0.671875,15.0,0.0,0.0
max,6.530759,0.875,15.0,41.0,0.842133


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.234262,0.889137,16.0,549.5,0.014693
min,6.124493,0.703125,16.0,0.0,0.0
max,6.78018,1.0,16.0,1099.0,16.162711


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.343631,0.790476,16.0,20.5,0.027118
min,6.213178,0.671875,16.0,0.0,0.0
max,6.526496,0.890625,16.0,41.0,1.138973


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.218178,0.894989,17.0,549.5,0.016525
min,6.104781,0.734375,17.0,0.0,0.0
max,6.677335,1.0,17.0,1099.0,18.177289


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.340065,0.797917,17.0,20.5,0.027928
min,6.215188,0.6875,17.0,0.0,0.0
max,6.518121,0.890625,17.0,41.0,1.172958


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.203911,0.900373,18.0,549.5,0.016809
min,6.097449,0.734375,18.0,0.0,0.0
max,6.615003,1.0,18.0,1099.0,18.489406


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.337189,0.805729,18.0,20.5,0.037609
min,6.203866,0.7,18.0,0.0,0.0
max,6.539531,0.921875,18.0,41.0,1.57957


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.189975,0.905643,19.0,549.5,0.01821
min,6.083302,0.765625,19.0,0.0,0.0
max,6.585424,1.0,19.0,1099.0,20.031322


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.349779,0.797917,19.0,20.5,0.018834
min,6.222229,0.6875,19.0,0.0,0.0
max,6.59286,0.890625,19.0,41.0,0.791021


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.177751,0.91016,20.0,549.5,0.020339
min,6.074153,0.765625,20.0,0.0,0.0
max,6.543071,1.0,20.0,1099.0,22.373084


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.325817,0.799033,20.0,20.5,0.018871
min,6.211692,0.6875,20.0,0.0,0.0
max,6.530685,0.890625,20.0,41.0,0.792584


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.166076,0.914921,21.0,549.5,0.018077
min,6.067479,0.78125,21.0,0.0,0.0
max,6.426173,1.0,21.0,1099.0,19.884393


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.332632,0.805357,21.0,20.5,0.020581
min,6.214358,0.671875,21.0,0.0,0.0
max,6.554402,0.890625,21.0,41.0,0.864388


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.156072,0.919208,22.0,549.5,0.019144
min,6.059338,0.78125,22.0,0.0,0.0
max,6.394172,1.0,22.0,1099.0,21.05826


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.319127,0.806101,22.0,20.5,0.018462
min,6.198081,0.7,22.0,0.0,0.0
max,6.549329,0.890625,22.0,41.0,0.77539


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.146138,0.923413,23.0,549.5,0.017545
min,6.052316,0.796875,23.0,0.0,0.0
max,6.391279,1.0,23.0,1099.0,19.299901


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.333318,0.803497,23.0,20.5,0.016107
min,6.20569,0.7,23.0,0.0,0.0
max,6.568433,0.890625,23.0,41.0,0.676484


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.137347,0.926583,24.0,549.5,0.017858
min,6.049476,0.796875,24.0,0.0,0.0
max,6.327239,1.0,24.0,1099.0,19.643628


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.320931,0.805208,24.0,20.5,0.021715
min,6.203773,0.6,24.0,0.0,0.0
max,6.527667,0.890625,24.0,41.0,0.912028


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.129528,0.930529,25.0,549.5,0.017003
min,6.045975,0.796875,25.0,0.0,0.0
max,6.307727,1.0,25.0,1099.0,18.702971


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.326632,0.801637,25.0,20.5,0.020514
min,6.200823,0.7,25.0,0.0,0.0
max,6.548907,0.890625,25.0,41.0,0.861596


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.121802,0.934029,26.0,549.5,0.015162
min,6.044334,0.796875,26.0,0.0,0.0
max,6.249592,1.0,26.0,1099.0,16.67848


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.31945,0.803125,26.0,20.5,0.030018
min,6.209145,0.7,26.0,0.0,0.0
max,6.519995,0.890625,26.0,41.0,1.260746


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.114778,0.937734,27.0,549.5,0.016786
min,6.039081,0.796875,27.0,0.0,0.0
max,6.238684,1.0,27.0,1099.0,18.465009


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.318436,0.803348,27.0,20.5,0.021426
min,6.197271,0.6,27.0,0.0,0.0
max,6.561569,0.890625,27.0,41.0,0.899892


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.108098,0.941029,28.0,549.5,0.020839
min,6.037891,0.828125,28.0,0.0,0.0
max,6.224783,1.0,28.0,1099.0,22.922713


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.319746,0.806101,28.0,20.5,0.016955
min,6.203893,0.7,28.0,0.0,0.0
max,6.494842,0.890625,28.0,41.0,0.712113


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.102303,0.943944,29.0,549.5,0.015725
min,6.034472,0.828125,29.0,0.0,0.0
max,6.213485,1.0,29.0,1099.0,17.297606


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.313543,0.803125,29.0,20.5,0.024804
min,6.190507,0.7,29.0,0.0,0.0
max,6.494826,0.890625,29.0,41.0,1.041771


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.096852,0.946569,30.0,549.5,0.016646
min,6.032105,0.828125,30.0,0.0,0.0
max,6.221531,1.0,30.0,1099.0,18.310949


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.315779,0.810342,30.0,20.5,0.016878
min,6.198049,0.71875,30.0,0.0,0.0
max,6.488978,0.90625,30.0,41.0,0.708857


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.092065,0.949157,31.0,549.5,0.018247
min,6.02883,0.828125,31.0,0.0,0.0
max,6.203898,1.0,31.0,1099.0,20.071911


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.319645,0.81183,31.0,20.5,0.033218
min,6.200032,0.71875,31.0,0.0,0.0
max,6.55112,0.890625,31.0,41.0,1.395143


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.087731,0.951813,32.0,549.5,0.019351
min,6.027898,0.828125,32.0,0.0,0.0
max,6.204471,1.0,32.0,1099.0,21.286316


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.319697,0.807961,32.0,20.5,0.023279
min,6.199957,0.7,32.0,0.0,0.0
max,6.531246,0.890625,32.0,41.0,0.977703


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.083454,0.954341,33.0,549.5,0.01781
min,6.027102,0.828125,33.0,0.0,0.0
max,6.198711,1.0,33.0,1099.0,19.590791


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.307635,0.805878,33.0,20.5,0.017978
min,6.193252,0.71875,33.0,0.0,0.0
max,6.505468,0.90625,33.0,41.0,0.755065


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.079052,0.95633,34.0,549.5,0.015344
min,6.025247,0.828125,34.0,0.0,0.0
max,6.199079,1.0,34.0,1099.0,16.878339


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.315454,0.80625,34.0,20.5,0.025399
min,6.196201,0.734375,34.0,0.0,0.0
max,6.521132,0.890625,34.0,41.0,1.066754


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.075104,0.958262,35.0,549.5,0.018504
min,6.023255,0.828125,35.0,0.0,0.0
max,6.196835,1.0,35.0,1099.0,20.354691


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.311813,0.803125,35.0,20.5,0.032179
min,6.198282,0.7,35.0,0.0,0.0
max,6.552401,0.921875,35.0,41.0,1.351526


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.071866,0.960563,36.0,549.5,0.017207
min,6.022419,0.828125,36.0,0.0,0.0
max,6.195935,1.0,36.0,1099.0,18.927194


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.303258,0.810491,36.0,20.5,0.024964
min,6.191872,0.703125,36.0,0.0,0.0
max,6.479833,0.90625,36.0,41.0,1.048502


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.068533,0.962282,37.0,549.5,0.016561
min,6.020772,0.828125,37.0,0.0,0.0
max,6.193917,1.0,37.0,1099.0,18.216604


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.30881,0.81131,37.0,20.5,0.015063
min,6.19741,0.7,37.0,0.0,0.0
max,6.488569,0.9375,37.0,41.0,0.632631


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.065369,0.964029,38.0,549.5,0.019429
min,6.020881,0.828125,38.0,0.0,0.0
max,6.192668,1.0,38.0,1099.0,21.371861


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.310102,0.806399,38.0,20.5,0.024685
min,6.197848,0.703125,38.0,0.0,0.0
max,6.492497,0.921875,38.0,41.0,1.03677


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.062552,0.965705,39.0,549.5,0.017059
min,6.019865,0.84375,39.0,0.0,0.0
max,6.191917,1.0,39.0,1099.0,18.764696


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.300151,0.808631,39.0,20.5,0.023695
min,6.1837,0.734375,39.0,0.0,0.0
max,6.445326,0.9375,39.0,41.0,0.995194


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.060019,0.966998,40.0,549.5,0.015723
min,6.019451,0.828125,40.0,0.0,0.0
max,6.1904,1.0,40.0,1099.0,17.295539


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.305227,0.80997,40.0,20.5,0.017398
min,6.192286,0.75,40.0,0.0,0.0
max,6.51299,0.90625,40.0,41.0,0.73072


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.058066,0.968549,41.0,549.5,0.017887
min,6.018453,0.84375,41.0,0.0,0.0
max,6.189762,1.0,41.0,1099.0,19.675268


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.301429,0.807589,41.0,20.5,0.02365
min,6.194448,0.7,41.0,0.0,0.0
max,6.520194,0.90625,41.0,41.0,0.99328


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.055612,0.969969,42.0,549.5,0.01754
min,6.017817,0.859375,42.0,0.0,0.0
max,6.181339,1.0,42.0,1099.0,19.294234


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.320459,0.806473,42.0,20.5,0.025524
min,6.215542,0.7,42.0,0.0,0.0
max,6.525511,0.90625,42.0,41.0,1.072025


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.053308,0.971333,43.0,549.5,0.017077
min,6.017522,0.875,43.0,0.0,0.0
max,6.165321,1.0,43.0,1099.0,18.784872


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.31755,0.807589,43.0,20.5,0.02184
min,6.205634,0.7,43.0,0.0,0.0
max,6.577052,0.90625,43.0,41.0,0.917277


HBox(children=(IntProgress(value=0, max=1100), HTML(value='')))




Unnamed: 0,loss,acc,epoch,iter,timestamp
mean,6.051644,0.972512,44.0,549.5,0.019045
min,6.01734,0.890625,44.0,0.0,0.0
max,6.153747,1.0,44.0,1099.0,20.949803


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

Exception in thread Thread-186:
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/pin_memory.py", line 21, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/opt/anaconda3/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 284, in rebuild_storage_fd
    fd = df.detach()
  File "/opt/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/opt/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authke

KeyboardInterrupt: 

In [91]:
test_df = pd.read_csv(DATA/'test.csv')
test_df.head(10)

Unnamed: 0,id_code,experiment,plate,well
0,HEPG2-08_1_B03,HEPG2-08,1,B03
1,HEPG2-08_1_B04,HEPG2-08,1,B04
2,HEPG2-08_1_B05,HEPG2-08,1,B05
3,HEPG2-08_1_B06,HEPG2-08,1,B06
4,HEPG2-08_1_B07,HEPG2-08,1,B07
5,HEPG2-08_1_B08,HEPG2-08,1,B08
6,HEPG2-08_1_B09,HEPG2-08,1,B09
7,HEPG2-08_1_B10,HEPG2-08,1,B10
8,HEPG2-08_1_B11,HEPG2-08,1,B11
9,HEPG2-08_1_B12,HEPG2-08,1,B12


In [92]:
test1_bc = bz.open("/data/rcic/actv_test1")

In [93]:
test1_bc.shape

(19897, 2048)

In [94]:
test1_df = generate_df(test_df, sample_num=1).reset_index().rename(columns={"index":"actid"})

In [95]:
test1_df

Unnamed: 0,actid,path,sirna,pname
0,0,HEPG2-08/Plate1/B03_s1_w,,HEPG2-08-1
1,1,HEPG2-08/Plate1/B04_s1_w,,HEPG2-08-1
2,2,HEPG2-08/Plate1/B05_s1_w,,HEPG2-08-1
3,3,HEPG2-08/Plate1/B06_s1_w,,HEPG2-08-1
4,4,HEPG2-08/Plate1/B07_s1_w,,HEPG2-08-1
...,...,...,...,...
19892,19892,U2OS-05/Plate4/O19_s1_w,,U2OS-05-4
19893,19893,U2OS-05/Plate4/O20_s1_w,,U2OS-05-4
19894,19894,U2OS-05/Plate4/O21_s1_w,,U2OS-05-4
19895,19895,U2OS-05/Plate4/O22_s1_w,,U2OS-05-4


In [96]:
force_group = np.load("force_group.npy", allow_pickle=True).tolist()

In [97]:
test1_df["grp"] = test1_df.pname.apply(lambda x:force_group[x])

In [110]:
test1_ds = actDs(test1_df,test_bc)
t1 = Trainer(test1_ds,shuffle=False, batch_size=1)

In [111]:
allpred1 = []

In [112]:
@t1.step_train
def action(batch):
    model.eval()
    if batch.i ==0:
        model.eval()
#     opt.zero_grad()
    x,y1,y2 = batch.data
    x  =x[0].float(); y1 = y1[0]
    if CUDA:
        x = x.cuda()
        y1 = y1.cuda()
        y2 = y2.cuda()
        
    y1_ = model(x)
    allpred1.append(y1_.cpu().data.numpy())
    return {"maxval":y1_.argmax(dim=-1)[0].item()}

In [113]:
t1.train(1)

HBox(children=(IntProgress(value=0, max=311), HTML(value='')))

  input = module(input)





In [120]:
np.save("shortpred1.npy",np.concatenate(allpred1, axis=0))