In [2]:
import torch
import torch.nn as nn
import os
import sys
import numpy as np
import argparse
import yaml
import pandas as pd
from matplotlib import pyplot as plt
sys.path.append('../dsc/')
sys.path.append("../..")
sys.path.append(os.getcwd())
from dsc_model import DSCModel
from bert2bert import Bert2BertSynCtrl
from transformers import BertConfig
# Debugging aid: make CUDA kernel launches synchronous so device-side errors are
# reported at the call that caused them. This slows GPU execution noticeably --
# consider removing it for long training runs. It only takes effect if set before
# CUDA is initialised (torch initialises CUDA lazily, so setting it here works).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Automatically reload edited modules (e.g. dsc_model, bert2bert) before each cell executes.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Training synthetic experiments on topk-donors.
# For every dataset size N and donor budget topk: build a Bert2BertSynCtrl model,
# wrap it in a DSCModel and fit it with pretraining enabled (the captured log shows
# a "Pretraining model on donor units" phase followed by "Fitting model on target
# unit"). Artifacts are written under ../logs_dir/synthetic_N_{N}_topk_{topk}/.
N_array = [26, 51, 76, 101]
topk_array = [5, 10]

# Settings shared by every (N, topk) run -- hoisted out of the loops since they never change.
random_seed = 0
target_index = 0
interv_time = 1600  # intervention time index passed to DSCModel.fit
lowrank = False
classes = None
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

for N in N_array:
    datapath = f'../datasets/synthetic_data_N_{N}/'
    config_path = f'../exp_configs/synthetic_N_{N}/config.yaml'

    for topk in topk_array:
        # Re-read the YAML for every run so the per-run edits below (and anything
        # DSCModel may do to `config`) cannot leak into the next run. The `with`
        # block guarantees the file handle is closed.
        with open(config_path, 'r') as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        config['K'] = topk
        # NOTE(review): config['seq_range'] is topk+1, but the model config below is
        # built with seq_range=topk. Kept unchanged so the architecture matches the
        # checkpoints the evaluation cell loads -- confirm which value is intended.
        config['seq_range'] = topk + 1
        op_path = f'../logs_dir/synthetic_N_{N}_topk_{topk}/'
        # NOTE(review): training data is read from synthetic_data_N_{N}/ while the
        # weights come from synthetic_data_N_{N}_1/ -- confirm this pairing.
        weights = np.load(f'../datasets/synthetic_data_N_{N}_1/weights.npy')

        config_model = BertConfig(hidden_size=config['hidden_size'],
                                  num_hidden_layers=config['n_layers'],
                                  num_attention_heads=config['n_heads'],
                                  intermediate_size=4 * config['hidden_size'],
                                  vocab_size=0,
                                  max_position_embeddings=0,
                                  output_hidden_states=True,
                                  )
        config_model.add_syn_ctrl_config(K=topk,
                                         pre_int_len=config['pre_int_len'],
                                         post_int_len=config['post_int_len'],
                                         feature_dim=config['feature_dim'],
                                         time_range=config['time_range'],
                                         seq_range=topk,
                                         cont_dim=config['cont_dim'],
                                         discrete_dim=config['discrete_dim'],
                                         classes=classes)
        model = Bert2BertSynCtrl(config_model, random_seed).to(device)
        dscmodel = DSCModel(model,
                            config,
                            op_path,
                            target_index,
                            random_seed,
                            datapath,
                            device,
                            topk=topk,
                            weights=weights,
                            lowrank=lowrank,
                            classes=classes)

        dscmodel.fit(interv_time, pretrain=True)

Pretraining model on donor units
Iteration:0	Loss_mean:0.6096203327178955	Loss_std:0.0
Iteration:100	Loss_mean:1.4476946642994881	Loss_std:0.6010017085717663
Iteration:200	Loss_mean:1.391789416372776	Loss_std:0.5903016633199265
Iteration:300	Loss_mean:0.9683531659841538	Loss_std:0.4764521487225832
Iteration:400	Loss_mean:0.7562964338064194	Loss_std:0.47123512357208386
Iteration:500	Loss_mean:0.2754240187630057	Loss_std:0.21223181793212828
Iteration:600	Loss_mean:0.14747840950265526	Loss_std:0.10060587207053474
Iteration:700	Loss_mean:0.123928152769804	Loss_std:0.07994539373022087
Iteration:800	Loss_mean:0.11159394984133542	Loss_std:0.06782196101409811
Iteration:900	Loss_mean:0.10051723930984735	Loss_std:0.05741748051635063
Iteration:1000	Loss_mean:0.09012435884214938	Loss_std:0.04984492805276641
Iteration:1100	Loss_mean:0.07019643181934952	Loss_std:0.04166823780107947
Iteration:1200	Loss_mean:0.061552679473534225	Loss_std:0.03823714668250587
Iteration:1300	Loss_mean:0.0509460100159049	

Iteration:900	Loss_mean:0.00859840096323751	Loss_std:0.004545122900502272
Iteration:1000	Loss_mean:0.008931447234936059	Loss_std:0.004946284552165954
Iteration:1100	Loss_mean:0.008221949038561433	Loss_std:0.005122049098075581
Iteration:1200	Loss_mean:0.0077627133438363674	Loss_std:0.0035227681404849586
Iteration:1300	Loss_mean:0.008475356745766476	Loss_std:0.007447833011814766
Iteration:1400	Loss_mean:0.009936735773226246	Loss_std:0.02488346670652784
Iteration:1500	Loss_mean:0.008164626676589251	Loss_std:0.010165636751076861
Iteration:1600	Loss_mean:0.008119498986052348	Loss_std:0.0038176649008743275
Iteration:1700	Loss_mean:0.0077754973201081155	Loss_std:0.0040470114815708835
Iteration:1800	Loss_mean:0.0072592085902579125	Loss_std:0.007117405131680456
Iteration:1900	Loss_mean:0.008195241040084511	Loss_std:0.004620283288071378
Iteration:2000	Loss_mean:0.006864034350728616	Loss_std:0.00313423626595389
Iteration:2100	Loss_mean:0.006822223423514515	Loss_std:0.0030800780828446964
Iteration

Iteration:6800	Loss_mean:0.015250356587348506	Loss_std:0.01317554235600395
Iteration:6900	Loss_mean:0.015018752787727862	Loss_std:0.012300185344114445
Iteration:7000	Loss_mean:0.013103573238477112	Loss_std:0.009195578214761097
Iteration:7100	Loss_mean:0.014679458980681374	Loss_std:0.01447865671650786
Iteration:7200	Loss_mean:0.016460441201925278	Loss_std:0.01350214539893964
Iteration:7300	Loss_mean:0.01269849992939271	Loss_std:0.010528180151847763
Iteration:7400	Loss_mean:0.016660198094323276	Loss_std:0.016961677013920762
Iteration:7500	Loss_mean:0.013734506468754262	Loss_std:0.018276162636921577
Iteration:7600	Loss_mean:0.01455277089960873	Loss_std:0.012585413858112113
Iteration:7700	Loss_mean:0.013052393742837011	Loss_std:0.008392146412887463
Iteration:7800	Loss_mean:0.01067867477191612	Loss_std:0.0077113553207168
Iteration:7900	Loss_mean:0.012382168815238402	Loss_std:0.008938102140748928
Iteration:8000	Loss_mean:0.011609600859228521	Loss_std:0.00709331587191071
Iteration:8100	Loss_m

Iteration:2700	Loss_mean:0.022498438586480915	Loss_std:0.018005038078220917
Iteration:2800	Loss_mean:0.027327046259306372	Loss_std:0.024559235142668195
Iteration:2900	Loss_mean:0.020066823174711316	Loss_std:0.015624134916805824
Iteration:3000	Loss_mean:0.02397716057021171	Loss_std:0.020055039929272257
Iteration:3100	Loss_mean:0.02313180876430124	Loss_std:0.015795123291514956
Iteration:3200	Loss_mean:0.02361310645705089	Loss_std:0.01911630374836617
Iteration:3300	Loss_mean:0.023708489753771572	Loss_std:0.018064966799945414
Iteration:3400	Loss_mean:0.02879051140509546	Loss_std:0.023200850344984446
Iteration:3500	Loss_mean:0.017398278675973416	Loss_std:0.013035145468777343
Iteration:3600	Loss_mean:0.020412286252249032	Loss_std:0.015770799404340164
Iteration:3700	Loss_mean:0.01978436051402241	Loss_std:0.014451871730449234
Iteration:3800	Loss_mean:0.022893168884329497	Loss_std:0.022117943170255935
Iteration:3900	Loss_mean:0.017006298657506705	Loss_std:0.01286309111080861
Iteration:4000	Loss

Iteration:3600	Loss_mean:0.0082670785789378	Loss_std:0.009843135114524665
Iteration:3700	Loss_mean:0.008091367814922706	Loss_std:0.00838483389560088
Iteration:3800	Loss_mean:0.007176592256873846	Loss_std:0.003490870770737293
Iteration:3900	Loss_mean:0.008525397122139112	Loss_std:0.00661972513667989
Iteration:4000	Loss_mean:0.007059677531942725	Loss_std:0.0028466058706493047
Iteration:4100	Loss_mean:0.007754878980340436	Loss_std:0.003415581253284872
Iteration:4200	Loss_mean:0.007327873674221337	Loss_std:0.005042739697290492
Iteration:4300	Loss_mean:0.007760205487720669	Loss_std:0.003568943960793951
Iteration:4400	Loss_mean:0.008440148750087246	Loss_std:0.010129437245517566
Iteration:4500	Loss_mean:0.007735413166228682	Loss_std:0.0036568704293865784
Iteration:4600	Loss_mean:0.00738289060886018	Loss_std:0.004804158142518972
Iteration:4700	Loss_mean:0.008599270017584785	Loss_std:0.006363283038471348
Iteration:4800	Loss_mean:0.00756592754740268	Loss_std:0.003708196851832251
Iteration:4900	L

Iteration:9500	Loss_mean:0.011624933283310384	Loss_std:0.006935805513450603
Iteration:9600	Loss_mean:0.010978024760261178	Loss_std:0.006816289771274294
Iteration:9700	Loss_mean:0.01114118386642076	Loss_std:0.007326678047067372
Iteration:9800	Loss_mean:0.01095197239657864	Loss_std:0.0074165162969275335
Iteration:9900	Loss_mean:0.009719572637695819	Loss_std:0.0051654095176136665
Modifying K
Fitting model on target unit
Iteration:0	Loss_mean:0.003322165459394455	Loss_std:0.0
Iteration:100	Loss_mean:0.009729792191646993	Loss_std:0.00635953637982047
Iteration:200	Loss_mean:0.011777466286439449	Loss_std:0.028858480342845105
Iteration:300	Loss_mean:0.008782308602239936	Loss_std:0.004434704246635683
Iteration:400	Loss_mean:0.00880246698623523	Loss_std:0.004564762594088058
Iteration:500	Loss_mean:0.008355568638071418	Loss_std:0.004394449994895426
Iteration:600	Loss_mean:0.00889773725764826	Loss_std:0.005066545714781054
Iteration:700	Loss_mean:0.010789025885751471	Loss_std:0.024035091572140156
I

Iteration:5400	Loss_mean:0.023653731590602547	Loss_std:0.01966791146718519
Iteration:5500	Loss_mean:0.021932823448441923	Loss_std:0.01755853137629873
Iteration:5600	Loss_mean:0.02097662601387128	Loss_std:0.015211096907270705
Iteration:5700	Loss_mean:0.02268236572155729	Loss_std:0.018800882739335258
Iteration:5800	Loss_mean:0.01752583357039839	Loss_std:0.012021209883970634
Iteration:5900	Loss_mean:0.020107481880113482	Loss_std:0.012835777993961166
Iteration:6000	Loss_mean:0.015833687684498728	Loss_std:0.010309411392522751
Iteration:6100	Loss_mean:0.017012863056734204	Loss_std:0.011110950268789601
Iteration:6200	Loss_mean:0.015843647731235252	Loss_std:0.014546081745530322
Iteration:6300	Loss_mean:0.017136239716783165	Loss_std:0.010676434414834079
Iteration:6400	Loss_mean:0.015151829802198336	Loss_std:0.011927234046966993
Iteration:6500	Loss_mean:0.014534557652659714	Loss_std:0.00979569048055013
Iteration:6600	Loss_mean:0.016675522703444585	Loss_std:0.012551963343277546
Iteration:6700	Los

Iteration:1300	Loss_mean:0.04406297867186368	Loss_std:0.040856610418864354
Iteration:1400	Loss_mean:0.03364916292950511	Loss_std:0.025188633643071882
Iteration:1500	Loss_mean:0.03079733667196706	Loss_std:0.02733257845889632
Iteration:1600	Loss_mean:0.03356646361295134	Loss_std:0.023947627502193283
Iteration:1700	Loss_mean:0.026380097959190607	Loss_std:0.023441321946633346
Iteration:1800	Loss_mean:0.02304957085521892	Loss_std:0.01772436032047717
Iteration:1900	Loss_mean:0.02640617543598637	Loss_std:0.020834733611547335
Iteration:2000	Loss_mean:0.01807836295571178	Loss_std:0.012818756264149807
Iteration:2100	Loss_mean:0.024176658825017513	Loss_std:0.020456240239169053
Iteration:2200	Loss_mean:0.01996487006545067	Loss_std:0.019425410005872938
Iteration:2300	Loss_mean:0.02125583649845794	Loss_std:0.016416663596231052
Iteration:2400	Loss_mean:0.01640748486854136	Loss_std:0.012775236131623041
Iteration:2500	Loss_mean:0.021148708730470388	Loss_std:0.017099674180073177
Iteration:2600	Loss_mean

Iteration:2200	Loss_mean:0.008100496914703399	Loss_std:0.011040779302225301
Iteration:2300	Loss_mean:0.008732994355959818	Loss_std:0.01120032009368454
Iteration:2400	Loss_mean:0.006949703777208925	Loss_std:0.0033057587536109673
Iteration:2500	Loss_mean:0.006893604061915539	Loss_std:0.003413912111207724
Iteration:2600	Loss_mean:0.00716515516396612	Loss_std:0.004049136677756442
Iteration:2700	Loss_mean:0.00817104545305483	Loss_std:0.015165988660294285
Iteration:2800	Loss_mean:0.008131788538303226	Loss_std:0.009304300973826342
Iteration:2900	Loss_mean:0.007394653874216601	Loss_std:0.004172644352526175
Iteration:3000	Loss_mean:0.006848157662898302	Loss_std:0.0029527272474256876
Iteration:3100	Loss_mean:0.006985919994767755	Loss_std:0.006147683490187983
Iteration:3200	Loss_mean:0.007535799704492092	Loss_std:0.0058220580664341956
Iteration:3300	Loss_mean:0.006292067226022482	Loss_std:0.00240334937984115
Iteration:3400	Loss_mean:0.006787593082990497	Loss_std:0.002936184157695972
Iteration:350

Iteration:8100	Loss_mean:0.012157962322235108	Loss_std:0.0076785372731163325
Iteration:8200	Loss_mean:0.010296113828662783	Loss_std:0.006657004659198936
Iteration:8300	Loss_mean:0.011174803697504104	Loss_std:0.006886978352002668
Iteration:8400	Loss_mean:0.010708624550607055	Loss_std:0.006220525352666797
Iteration:8500	Loss_mean:0.01060174219077453	Loss_std:0.005932937934229726
Iteration:8600	Loss_mean:0.010686724296538158	Loss_std:0.006020271539857939
Iteration:8700	Loss_mean:0.009926300384104253	Loss_std:0.005487900484004298
Iteration:8800	Loss_mean:0.00982009653467685	Loss_std:0.004593450776371689
Iteration:8900	Loss_mean:0.011553267971612513	Loss_std:0.007885229353486
Iteration:9000	Loss_mean:0.011700253523886204	Loss_std:0.007283646278182296
Iteration:9100	Loss_mean:0.011486176354810595	Loss_std:0.00567482369467753
Iteration:9200	Loss_mean:0.011607739010360092	Loss_std:0.006795513566683619
Iteration:9300	Loss_mean:0.012618903145194054	Loss_std:0.009464684651563596
Iteration:9400	Lo

Iteration:4000	Loss_mean:0.024324943935498597	Loss_std:0.023954698024512056
Iteration:4100	Loss_mean:0.02196585826575756	Loss_std:0.021022907457574885
Iteration:4200	Loss_mean:0.02769124072510749	Loss_std:0.026165701476728326
Iteration:4300	Loss_mean:0.022790176330599935	Loss_std:0.0204566425985192
Iteration:4400	Loss_mean:0.02150439165532589	Loss_std:0.01897286572328346
Iteration:4500	Loss_mean:0.02383730928879231	Loss_std:0.019877689424823833
Iteration:4600	Loss_mean:0.021393812715541572	Loss_std:0.018762238604299538
Iteration:4700	Loss_mean:0.02158021560870111	Loss_std:0.019749728964518057
Iteration:4800	Loss_mean:0.02155598153825849	Loss_std:0.02298673830005887
Iteration:4900	Loss_mean:0.02540647909976542	Loss_std:0.020217291293499872
Iteration:5000	Loss_mean:0.022506068625953047	Loss_std:0.020512743657062557
Iteration:5100	Loss_mean:0.020010658092796802	Loss_std:0.01614832524198738
Iteration:5200	Loss_mean:0.020159352187765763	Loss_std:0.017199476987062753
Iteration:5300	Loss_mean

Iteration:4800	Loss_mean:0.0074538668198511	Loss_std:0.003369678897280111
Iteration:4900	Loss_mean:0.0077185405313503	Loss_std:0.0032010356634940246


In [2]:
# Evaluating the topk-donor checkpoints on the synthetic datasets.
# For every (N, topk) the checkpoint at .../finetune/model.pth is loaded and run on
# datasets synthetic_data_N_{N}_1 .. _{n_runs}. The post-intervention forecast is
# compared with row 0 of mean1.npy and the RMSE is printed. All RMSEs are also
# collected into a table displayed at the end of the cell.
N_array = [26, 51, 76, 101]
topk_array = [5, 10]
n_runs = 3  # number of dataset replicates per N (directories suffixed _1 .. _3)

# Settings shared by every run -- hoisted out of the loops since they never change.
random_seed = 0
target_index = 0
interv_time = 1600  # forecasts and ground truth are compared from this index onward
lowrank = False
classes = None
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

results = []  # one record per (N, topk, run); rendered as a DataFrame below
for N in N_array:
    config_path = f'../exp_configs/synthetic_N_{N}/config.yaml'

    for topk in topk_array:
        modelpath = f'../logs_dir/synthetic_N_{N}_topk_{topk}/finetune/model.pth'
        op_path = f'../logs_dir/synthetic_N_{N}_topk_{topk}/'

        # `run_id` (not `id`, which would shadow the builtin) is 1-based to match the
        # dataset directory suffix.
        for run_id in range(1, n_runs + 1):
            datapath = f'../datasets/synthetic_data_N_{N}_{run_id}/'
            # Fresh config per run so mutations cannot leak between runs; `with` closes the file.
            with open(config_path, 'r') as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)
            config['K'] = topk
            # NOTE(review): seq_range is topk+1 here but topk in the model config below;
            # must match what the checkpoint was trained with -- confirm intent.
            config['seq_range'] = topk + 1
            # NOTE(review): weights always come from replicate _1, even for runs 2 and 3
            # -- confirm the replicates share the same weights.
            weights = np.load(f'../datasets/synthetic_data_N_{N}_1/weights.npy')

            config_model = BertConfig(hidden_size=config['hidden_size'],
                                      num_hidden_layers=config['n_layers'],
                                      num_attention_heads=config['n_heads'],
                                      intermediate_size=4 * config['hidden_size'],
                                      vocab_size=0,
                                      max_position_embeddings=0,
                                      output_hidden_states=True,
                                      )
            config_model.add_syn_ctrl_config(K=topk,
                                             pre_int_len=config['pre_int_len'],
                                             post_int_len=config['post_int_len'],
                                             feature_dim=config['feature_dim'],
                                             time_range=config['time_range'],
                                             seq_range=topk,
                                             cont_dim=config['cont_dim'],
                                             discrete_dim=config['discrete_dim'],
                                             classes=classes)
            model = Bert2BertSynCtrl(config_model, random_seed).to(device)
            dscmodel = DSCModel(model,
                                config,
                                op_path,
                                target_index,
                                random_seed,
                                datapath,
                                device,
                                topk=topk,
                                weights=weights,
                                lowrank=lowrank,
                                classes=classes)
            dscmodel.load_model_from_checkpoint(modelpath)

            # Keep only the post-intervention part of the forecast.
            op_pred = dscmodel.predict(interv_time)[interv_time:]

            # NOTE(review): this writes predictions into the *dataset* directory as
            # target.npy -- confirm it cannot overwrite an input file.
            np.save(datapath + 'target.npy', op_pred)
            meanmatrix = np.load(datapath + 'mean1.npy')
            # Row 0 -- presumably the target unit's ground-truth mean (target_index == 0); verify.
            test_mean = meanmatrix[0, interv_time:]
            err = np.sqrt(np.mean((op_pred - test_mean) ** 2))  # RMSE over the post period
            print(f'For N = {N} topk = {topk}, id = {run_id} err is {err}')
            results.append({'N': N, 'topk': topk, 'id': run_id, 'rmse': float(err)})

pd.DataFrame(results)

Modifying K
For N = 26 topk = 5, id = 1 err is 1.0050761699676514
Modifying K
For N = 26 topk = 5, id = 2 err is 0.7290533781051636
Modifying K
For N = 26 topk = 5, id = 3 err is 0.7088204026222229
Modifying K
For N = 26 topk = 10, id = 1 err is 1.0192203521728516
Modifying K
For N = 26 topk = 10, id = 2 err is 0.7661535143852234
Modifying K
For N = 26 topk = 10, id = 3 err is 0.7655360698699951
Modifying K
For N = 51 topk = 5, id = 1 err is 0.8421191573143005
Modifying K
For N = 51 topk = 5, id = 2 err is 0.9415892958641052
Modifying K
For N = 51 topk = 5, id = 3 err is 0.7476657032966614
Modifying K
For N = 51 topk = 10, id = 1 err is 1.042747139930725
Modifying K
For N = 51 topk = 10, id = 2 err is 1.1387977600097656
Modifying K
For N = 51 topk = 10, id = 3 err is 0.9446754455566406
Modifying K
For N = 76 topk = 5, id = 1 err is 1.3713016510009766
Modifying K
For N = 76 topk = 5, id = 2 err is 1.218398928642273
Modifying K
For N = 76 topk = 5, id = 3 err is 1.1971648931503296
Modify