In [1]:
from data_preparation import *

In [2]:
import torch
import numpy as np
import torch.optim as optim
from sklearn.model_selection import KFold
from GSR_Net.preprocessing import *
from GSR_Net.model import *
from GSR_Net.train import *
import argparse

seed_value = 42
np.random.seed(seed_value)


In [3]:
path_to_data = '/vol/bitbucket/km2120/DGL_Project/DGL24-Group-Project/data'

x_train, x_test, y_train = load_data_tensor(path_to_data)

In [21]:
X, Y = x_train, y_train

In [26]:
y_train.max()

tensor(0.9999)

In [12]:
class ModelArgs:  
    def __init__(self, epochs, lr, splits, lmbda, lr_dim, hr_dim, hidden_dim,
                 padding):
        self.epochs = epochs
        self.lr = lr
        self.splits = splits
        self.lmbda = lmbda
        self.lr_dim = lr_dim
        self.hr_dim = hr_dim
        self.hidden_dim = hidden_dim
        self.padding = padding

args = ModelArgs(
    epochs=200,
    lr=0.0001,
    splits=3,
    lmbda=16,
    lr_dim=160,
    #hr_dim=320,
    #hidden_dim=320,
    hr_dim=320,
    hidden_dim=320,
    padding=26,
)

In [11]:
cv = KFold(n_splits=args.splits, random_state=42, shuffle=True)
ks = [0.9, 0.7, 0.6, 0.5]


In [13]:
ks = [0.9, 0.7, 0.6, 0.5]
model = GSRNet(ks, args)
print(model)

GSRNet(
  (layer): GSRLayer()
  (net): GraphUnet(
    (start_gcn): GCN(
      (proj): Linear(in_features=160, out_features=320, bias=True)
      (drop): Dropout(p=0, inplace=False)
    )
    (bottom_gcn): GCN(
      (proj): Linear(in_features=320, out_features=320, bias=True)
      (drop): Dropout(p=0, inplace=False)
    )
    (end_gcn): GCN(
      (proj): Linear(in_features=640, out_features=320, bias=True)
      (drop): Dropout(p=0, inplace=False)
    )
  )
  (gc1): GraphConvolution()
  (gc2): GraphConvolution()
)


In [13]:
Y.size()

torch.Size([167, 268, 268])

In [14]:
def train_model():
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    for train_index, test_index in cv.split(X):
        subjects_adj, test_adj, subjects_ground_truth, test_ground_truth = X[
            train_index], X[test_index], Y[train_index], Y[test_index]
        train(model, optimizer, subjects_adj, subjects_ground_truth, args)
        test(model, test_adj, test_ground_truth, args)
    

In [16]:
from evaluations import *

In [14]:
saved_state_dict = torch.load('/vol/bitbucket/km2120/DGL_Project/DGL24-Group-Project/GSR_models/vanilla_model.pth')
model.load_state_dict(saved_state_dict, strict=True)

<All keys matched successfully>

In [15]:
preds_list = []
model.eval()
with torch.no_grad():
    for test_data in x_test:
        pred, _, _, _ = model(test_data)
        pred = unpad(pred, args.padding)
        preds_list.append(pred)
        
pred_tensor = torch.stack(preds_list)

In [20]:
pred_tensor[0].max()

tensor(1.7666)

In [19]:
pred_tensor.size()

torch.Size([112, 268, 268])

In [20]:
evaluate(pred_tensor, Y)

0

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

MAE:  0.17028992
PCC:  0.4419352749875999
Jensen-Shannon Distance:  0.3305826901548274
Average MAE betweenness centrality: 0.022196793777044616
Average MAE eigenvector centrality: 0.01697010370854288
Average MAE PageRank centrality: 0.0007530632349348782


In [15]:
pred = model(x_test[0])[0]

In [16]:
pred.size()

torch.Size([320, 320])

In [43]:
from data_preparation import *

In [60]:
def find_preds_and_convert_to_submission(file_name, model=model, test_set=x_test):
    preds_list = []
    with torch.no_grad():
        for test_data in test_set:
            pred, _, _, _ = model(test_data)
            pred = unpad(pred, args.padding)
            preds_list.append(pred)
    pred_tensor = torch.stack(preds_list).cpu().numpy()
    generate_submission_file(pred_tensor, f'/vol/bitbucket/km2120/DGL_Project/DGL24-Group-Project/submission_files/{file_name}')        