In [1]:
import torch
from torch.cuda import random
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from os import listdir
from os.path import isfile, join
from tracin.tracin_batched import save_tracin_checkpoint, load_tracin_checkpoint,  approximate_tracin_batched
import pandas as pd
from LSTM_clean.model import LSTM
import numpy as np
import re
from statistics import mean
import scipy.stats as stats
import pandas as pd
from sklearn.utils import shuffle
from copy import deepcopy

# Global Parameters

In [2]:
# Passed as `output_size` to the LSTM constructor and as `num_items` to
# approximate_tracin_batched in the cells below.
# NOTE(review): presumably the number of distinct items the model can predict
# for this dataset — confirm against the training code.
OUTPUT_SIZE = 1743

# Collect Checkpoints and Identify the Last One

In [3]:
# Collect every file in ./checkpoints_subset/ and identify the checkpoint saved
# at the highest epoch.
curr_dir = os.getcwd()
path = curr_dir + "/checkpoints_subset/"


def _checkpoint_epoch(checkpoint_path):
    """Return the epoch number encoded in a checkpoint filename.

    The epoch is read from the digits that end the file's base name, just
    before the extension (e.g. ``.../model_12.pt`` -> ``12``). Digits that
    appear elsewhere in the path are ignored.

    Args:
        checkpoint_path: Path to a checkpoint file.

    Returns:
        The epoch as an ``int``, or ``None`` when the base name does not end
        in digits.
    """
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    trailing_digits = re.search(r"[0-9]+$", stem)
    return int(trailing_digits.group()) if trailing_digits else None


def _checkpoint_sort_key(checkpoint_path):
    """Order checkpoints by numeric epoch, placing files without one last."""
    epoch = _checkpoint_epoch(checkpoint_path)
    return (epoch is None, epoch if epoch is not None else 0, checkpoint_path)


checkpoints = []
with os.scandir(path) as listOfEntries:
    for entry in listOfEntries:
        # Keep only entries that DirEntry.is_file() reports as files.
        if entry.is_file():
            checkpoints.append(os.path.join(path, entry.name))

# os.scandir yields entries in arbitrary order, so sort to get a reproducible,
# epoch-ascending checkpoint sequence.
checkpoints.sort(key=_checkpoint_sort_key)

epoch_checkpoints = [c for c in checkpoints if _checkpoint_epoch(c) is not None]
if not epoch_checkpoints:
    raise ValueError(f"No checkpoint files with an epoch number found in {path}")

# Epochs are compared as integers: comparing the digit strings would rank "9"
# above "10". The list is already epoch-ascending, so the last entry wins.
last_checkpoint = epoch_checkpoints[-1]
last_checkpoint_epoch = str(_checkpoint_epoch(last_checkpoint))

# Set Up Devices

In [4]:
# Configure which GPU this process may use, then build the two torch devices
# used by the rest of the notebook.
# NOTE(review): these variables are set after `import torch`; presumably this
# still takes effect because no CUDA call has been made yet — confirm.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "6",
    }
)

cpu_device = torch.device("cpu")
print("CPU Device is ", cpu_device)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device is ", device)

CPU Device is  cpu
device is  cuda


# Load In Data

In [5]:
# Read the training data file located under the current working directory.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code — only load data files you trust.
train = np.load(
    file=os.path.join(os.getcwd(), "data/twitch_sequence/train.data"),
    allow_pickle=True,
)

In [6]:
# Build an LSTM instance on the CPU device.
# NOTE(review): the TracIn calls visible later in this notebook receive the
# LSTM class itself rather than this instance — confirm whether `model` is
# still needed.
model = LSTM(
    input_size=128,
    hidden_dim=64,
    n_layers=1,
    output_size=OUTPUT_SIZE,
    device=cpu_device,
)

In [7]:
# Separate each record of `train` into its two parts: element 1 goes to
# `train_labels`, element 0 stays in `train`.
# This rebinds `train`, so re-running the cell without reloading the data
# would index into the already-split values.
train_num = len(train)
train_labels = [train[idx][1] for idx in range(train_num)]
train = [train[idx][0] for idx in range(train_num)]


# Subset Data
This step is optional: skip it to run the experiment on the full training set (which takes roughly 24 hours).

In [8]:
# Shuffle inputs and labels together with a fixed seed, then keep the first
# 3000 entries of each so the experiment runs on a reproducible subset.
train, train_labels = (
    shuffled[:3000]
    for shuffled in shuffle(train, train_labels, random_state=201)
)

# Deepcopy to prevent overwriting

In [9]:
# Take independent deep copies of the inputs and of the labels; the experiment
# loop shuffles these copies rather than `train` / `train_labels`.
train_copy, train_labels_copy = (
    deepcopy(original) for original in (train, train_labels)
)

# Run Experiments

In [10]:
# Per-iteration results; both lists are reset here, so re-running the cell
# starts from scratch.
si = []  # self-influence value returned for each iteration
ri = []  # random-sample influence value returned for each iteration

# Keyword arguments shared by every TracIn call in this cell.
tracin_settings = dict(
    optimizer="SGD",
    paths=checkpoints,
    batch_size=2048,
    num_items=OUTPUT_SIZE,
    device=device,
)

for i in range(20):
    print("___________________________________________________________________________________")
    print(f"Iteration {i}")

    # Self influence: the shuffled set is used as both sources and targets.
    train_random, train_labels_random = shuffle(
        train_copy, train_labels_copy, random_state=i
    )
    self_influence = approximate_tracin_batched(
        LSTM,
        sources=train_random,
        targets=train_random,
        source_labels=train_labels_random,
        target_labels=train_labels_random,
        **tracin_settings,
    )
    print(f"Self influence is: {self_influence}")
    si.append(self_influence)

    # Random-sample influence: sources stay in `train` order while targets use
    # the shuffled order. The shuffle is repeated with the same inputs and the
    # same seed as above.
    # NOTE(review): presumably this yields the same permutation again, making
    # the second shuffle redundant unless the TracIn call mutates its inputs —
    # confirm.
    train_random, train_labels_random = shuffle(
        train_copy, train_labels_copy, random_state=i
    )
    rs = approximate_tracin_batched(
        LSTM,
        sources=train,
        targets=train_random,
        source_labels=train_labels,
        target_labels=train_labels_random,
        **tracin_settings,
    )
    print(f"Random Sample {i} Influence is {rs}")
    ri.append(rs)

___________________________________________________________________________________
Iteration 0
In checkpoint number: 0
Total time for checkpoint 0 : 24.67806100845337
In checkpoint number: 1
Total time for checkpoint 1 : 12.459105730056763
In checkpoint number: 2
Total time for checkpoint 2 : 11.063999652862549
In checkpoint number: 3
Total time for checkpoint 3 : 12.087036848068237
In checkpoint number: 4
Total time for checkpoint 4 : 11.126459836959839
In checkpoint number: 5
Total time for checkpoint 5 : 14.599980592727661
In checkpoint number: 6
Total time for checkpoint 6 : 10.92798137664795
In checkpoint number: 7
Total time for checkpoint 7 : 10.986692667007446
In checkpoint number: 8
Total time for checkpoint 8 : 11.02364706993103
In checkpoint number: 9
Total time for checkpoint 9 : 7.775242567062378
In checkpoint number: 10
Total time for checkpoint 10 : 9.68164348602295
In checkpoint number: 11
Total time for checkpoint 11 : 10.861163854598999
In checkpoint number: 12
Total

Total time for checkpoint 19 : 13.25405216217041
In checkpoint number: 20
Total time for checkpoint 20 : 13.242337465286255
Total time taken is 300.5848665237427
Self influence is: 0.08289601653814316
In checkpoint number: 0
Total time for checkpoint 0 : 1.4141178131103516
In checkpoint number: 1
Total time for checkpoint 1 : 1.6256227493286133
In checkpoint number: 2
Total time for checkpoint 2 : 1.3595144748687744
In checkpoint number: 3
Total time for checkpoint 3 : 1.6997575759887695
In checkpoint number: 4
Total time for checkpoint 4 : 2.9019246101379395
In checkpoint number: 5
Total time for checkpoint 5 : 4.243244647979736
In checkpoint number: 6
Total time for checkpoint 6 : 6.272408723831177
In checkpoint number: 7
Total time for checkpoint 7 : 6.016612529754639
In checkpoint number: 8
Total time for checkpoint 8 : 1.3535501956939697
In checkpoint number: 9
Total time for checkpoint 9 : 5.641200542449951
In checkpoint number: 10
Total time for checkpoint 10 : 9.005222082138062

Total time for checkpoint 18 : 27.100098848342896
In checkpoint number: 19
Total time for checkpoint 19 : 21.389662742614746
In checkpoint number: 20
Total time for checkpoint 20 : 20.66663098335266
Total time taken is 407.93447184562683
Random Sample 4 Influence is 0.04387207701802254
___________________________________________________________________________________
Iteration 5
In checkpoint number: 0
Total time for checkpoint 0 : 21.24635672569275
In checkpoint number: 1
Total time for checkpoint 1 : 24.93926477432251
In checkpoint number: 2
Total time for checkpoint 2 : 22.850805044174194
In checkpoint number: 3
Total time for checkpoint 3 : 23.511293411254883
In checkpoint number: 4
Total time for checkpoint 4 : 25.183452129364014
In checkpoint number: 5
Total time for checkpoint 5 : 31.177069425582886
In checkpoint number: 6
Total time for checkpoint 6 : 26.279199838638306
In checkpoint number: 7
Total time for checkpoint 7 : 24.548200845718384
In checkpoint number: 8
Total time 

Total time for checkpoint 16 : 33.74471974372864
In checkpoint number: 17
Total time for checkpoint 17 : 44.03439545631409
In checkpoint number: 18
Total time for checkpoint 18 : 53.45825457572937
In checkpoint number: 19
Total time for checkpoint 19 : 42.35509920120239
In checkpoint number: 20
Total time for checkpoint 20 : 45.92166543006897
Total time taken is 802.3150324821472
Self influence is: 0.08344727754592896
In checkpoint number: 0
Total time for checkpoint 0 : 44.39928460121155
In checkpoint number: 1
Total time for checkpoint 1 : 41.55577063560486
In checkpoint number: 2
Total time for checkpoint 2 : 45.42960810661316
In checkpoint number: 3
Total time for checkpoint 3 : 42.39766335487366
In checkpoint number: 4
Total time for checkpoint 4 : 56.70457100868225
In checkpoint number: 5
Total time for checkpoint 5 : 41.27765703201294
In checkpoint number: 6
Total time for checkpoint 6 : 38.58808088302612
In checkpoint number: 7
Total time for checkpoint 7 : 12.246232271194458
I

Total time for checkpoint 15 : 31.217378616333008
In checkpoint number: 16
Total time for checkpoint 16 : 39.975196838378906
In checkpoint number: 17
Total time for checkpoint 17 : 9.138544797897339
In checkpoint number: 18
Total time for checkpoint 18 : 5.90505051612854
In checkpoint number: 19
Total time for checkpoint 19 : 27.147632360458374
In checkpoint number: 20
Total time for checkpoint 20 : 25.670841455459595
Total time taken is 593.3387336730957
Random Sample 9 Influence is 0.044172290712594986
___________________________________________________________________________________
Iteration 10
In checkpoint number: 0
Total time for checkpoint 0 : 37.51439046859741
In checkpoint number: 1
Total time for checkpoint 1 : 39.42399525642395
In checkpoint number: 2
Total time for checkpoint 2 : 38.284584283828735
In checkpoint number: 3
Total time for checkpoint 3 : 36.51716923713684
In checkpoint number: 4
Total time for checkpoint 4 : 10.271988153457642
In checkpoint number: 5
Total t

Total time for checkpoint 13 : 37.42380475997925
In checkpoint number: 14
Total time for checkpoint 14 : 38.89318490028381
In checkpoint number: 15
Total time for checkpoint 15 : 38.71201276779175
In checkpoint number: 16
Total time for checkpoint 16 : 38.417989015579224
In checkpoint number: 17
Total time for checkpoint 17 : 37.53197908401489
In checkpoint number: 18
Total time for checkpoint 18 : 44.14405345916748
In checkpoint number: 19
Total time for checkpoint 19 : 29.619477033615112
In checkpoint number: 20
Total time for checkpoint 20 : 14.162635803222656
Total time taken is 713.411257982254
Self influence is: 0.07790303230285645
In checkpoint number: 0
Total time for checkpoint 0 : 5.595665693283081
In checkpoint number: 1
Total time for checkpoint 1 : 16.284074544906616
In checkpoint number: 2
Total time for checkpoint 2 : 26.638946771621704
In checkpoint number: 3
Total time for checkpoint 3 : 41.34564971923828
In checkpoint number: 4
Total time for checkpoint 4 : 43.8219594

Total time for checkpoint 12 : 39.234631061553955
In checkpoint number: 13
Total time for checkpoint 13 : 15.282129049301147
In checkpoint number: 14
Total time for checkpoint 14 : 10.758545398712158
In checkpoint number: 15
Total time for checkpoint 15 : 13.56417989730835
In checkpoint number: 16
Total time for checkpoint 16 : 21.25850749015808
In checkpoint number: 17
Total time for checkpoint 17 : 26.18394112586975
In checkpoint number: 18
Total time for checkpoint 18 : 44.03061652183533
In checkpoint number: 19
Total time for checkpoint 19 : 39.14077162742615
In checkpoint number: 20
Total time for checkpoint 20 : 40.407572507858276
Total time taken is 747.6819772720337
Random Sample 14 Influence is 0.04501301050186157
___________________________________________________________________________________
Iteration 15
In checkpoint number: 0
Total time for checkpoint 0 : 36.33926558494568
In checkpoint number: 1
Total time for checkpoint 1 : 39.38518023490906
In checkpoint number: 2
To

Total time for checkpoint 10 : 11.888633966445923
In checkpoint number: 11
Total time for checkpoint 11 : 32.48818922042847
In checkpoint number: 12
Total time for checkpoint 12 : 3.000398874282837
In checkpoint number: 13
Total time for checkpoint 13 : 36.45073223114014
In checkpoint number: 14
Total time for checkpoint 14 : 39.89703440666199
In checkpoint number: 15
Total time for checkpoint 15 : 37.954222202301025
In checkpoint number: 16
Total time for checkpoint 16 : 33.5154926776886
In checkpoint number: 17
Total time for checkpoint 17 : 34.83865451812744
In checkpoint number: 18
Total time for checkpoint 18 : 21.55147361755371
In checkpoint number: 19
Total time for checkpoint 19 : 35.99519228935242
In checkpoint number: 20
Total time for checkpoint 20 : 41.74561047554016
Total time taken is 742.122927904129
Self influence is: 0.08014470338821411
In checkpoint number: 0
Total time for checkpoint 0 : 40.83921957015991
In checkpoint number: 1
Total time for checkpoint 1 : 44.62824

Total time for checkpoint 9 : 32.77412557601929
In checkpoint number: 10
Total time for checkpoint 10 : 40.891037940979004
In checkpoint number: 11
Total time for checkpoint 11 : 42.48768925666809
In checkpoint number: 12
Total time for checkpoint 12 : 43.69516205787659
In checkpoint number: 13
Total time for checkpoint 13 : 42.05306315422058
In checkpoint number: 14
Total time for checkpoint 14 : 44.05160403251648
In checkpoint number: 15
Total time for checkpoint 15 : 36.61167812347412
In checkpoint number: 16
Total time for checkpoint 16 : 35.64877724647522
In checkpoint number: 17
Total time for checkpoint 17 : 50.46954298019409
In checkpoint number: 18
Total time for checkpoint 18 : 37.86741375923157
In checkpoint number: 19
Total time for checkpoint 19 : 36.73503518104553
In checkpoint number: 20
Total time for checkpoint 20 : 37.65563178062439
Total time taken is 745.1570444107056
Random Sample 19 Influence is 0.05317528545856476


In [11]:
# Convert every collected influence value to a plain Python float, then show
# both lists.
ri = list(map(float, ri))
si = list(map(float, si))
print("Random influences are \n", ri)
print("Self influences are \n", si)

Random influences are 
 [0.045346006751060486, 0.047535426914691925, 0.04751908406615257, 0.0457342155277729, 0.04387207701802254, 0.05597292631864548, 0.05046650767326355, 0.04232533276081085, 0.04699711129069328, 0.044172290712594986, 0.04608522355556488, 0.04713398963212967, 0.04536144435405731, 0.04832974821329117, 0.04501301050186157, 0.042381975799798965, 0.051136333495378494, 0.041778214275836945, 0.05042362958192825, 0.05317528545856476]
Self influences are 
 [0.0804448276758194, 0.07881506532430649, 0.08289601653814316, 0.07580509036779404, 0.07611887156963348, 0.08564634621143341, 0.08174218237400055, 0.08344727754592896, 0.08448409289121628, 0.08004724234342575, 0.0788716971874237, 0.07925650477409363, 0.07790303230285645, 0.07972077280282974, 0.07766684144735336, 0.08244019746780396, 0.0861254334449768, 0.08014470338821411, 0.08581280708312988, 0.0826517716050148]


# Perform Statistical Tests

In [12]:
# Two-sample t-test between the self influences and the random-sample
# influences; equal_var=False selects the unequal-variance (Welch) form.
welch_result = stats.ttest_ind(a=np.array(si), b=np.array(ri), equal_var=False)
print(f"Difference in population is {welch_result}")

Difference in population is Ttest_indResult(statistic=31.302050349891296, pvalue=3.4555497662919746e-28)
