To use PyTorch Geometric Temporal, make sure torch 1.9.0 is installed (uninstall 1.10.0 first if it is present).

In [1]:
# Print the installed torch version so it can be checked against the
# version-specific wheel index used by the install cell.
import torch
print(torch.__version__)

1.9.0+cpu


In [2]:
import torch
import numpy as np
import pandas as pd

In [3]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-geometric
!pip install torch-geometric-temporal

Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html


In [4]:
from sklearn.preprocessing import normalize
from sklearn.preprocessing import MinMaxScaler

def transform_and_split(data):
    """Normalise a graph's tensors, binarise its labels and print a summary.

    Despite the name, no train/test split is performed: the split code was
    disabled and only the transformation remains.

    Args:
        data: graph object exposing ``x``, ``y`` and ``edge_attr`` tensors.
            It is modified in place.

    Returns:
        The same ``data`` object, with ``x`` row-normalised (sklearn
        ``normalize`` with ``norm='max'``) as float64, ``y`` mapped to
        {0., 1.} as float64 and ``edge_attr`` cast to float64.
    """
    # sklearn returns a NumPy array, so convert back to a tensor afterwards.
    data.x = torch.from_numpy(normalize(data.x, axis=1, norm='max')).to(torch.float64)

    # Binary target: 1 where the label is strictly positive, else 0.
    # A vectorised comparison replaces the per-element ``apply_`` lambda.
    data.y = (data.y > 0).to(torch.float64)

    data.edge_attr = data.edge_attr.to(torch.double)

    print_basic_info(data)
    return data

In [5]:
def print_basic_info(data):
    """Print the graph's repr followed by a few summary statistics.

    Args:
        data: graph object providing ``num_nodes``, ``num_edges`` and the
            ``has_isolated_nodes``, ``has_self_loops`` and ``is_undirected``
            methods.
    """
    print()
    print(data)
    print('===========================================================================================================')

    # Each value is computed lazily, right before its line is printed.
    summary = (
        ('Number of nodes', lambda: data.num_nodes),
        ('Number of edges', lambda: data.num_edges),
        ('Average node degree', lambda: f'{data.num_edges / data.num_nodes:.2f}'),
        ('Has isolated nodes', lambda: data.has_isolated_nodes()),
        ('Has self-loops', lambda: data.has_self_loops()),
        ('Is undirected', lambda: data.is_undirected()),
    )
    for label, compute in summary:
        print(f'{label}: {compute()}')

### Get and split data

In [6]:
# Load one saved quarterly Twitter graph and run it through the transform as
# a smoke test. torch.load unpickles the file, so only load trusted files.
path = "../data/processed/twitter/2018_q1.pt" # Customize...
dataset = torch.load(path)
# The first element of the loaded object is the graph itself.
data = dataset[0]
transformed_data = transform_and_split(data)


Data(x=[29, 61], edge_index=[2, 400], edge_attr=[400], y=[29])
Number of nodes: 29
Number of edges: 400
Average node degree: 13.79
Has isolated nodes: False
Has self-loops: True
Is undirected: True


In [7]:
# Weekly news articles; the displayed frame has Date, url and texts columns.
df = pd.read_csv('../data/raw/news/news_data_weekly.csv')

In [8]:
# Drop every row with a missing value in any column. `df` is rebound, so the
# raw frame is no longer available after this cell.
df = df.dropna()

In [9]:
# Order rows by Date. The column is only parsed to datetime in a later cell;
# presumably it holds ISO-formatted strings here, which sort chronologically.
df = df.sort_values('Date')

In [10]:
df

Unnamed: 0,Date,url,texts
2746,2016-10-08,https://www.nytimes.com/2016/10/04/opinion/the...,Donald Trump is a thug. He’s a thug who talks...
33,2016-10-08,https://www.nytimes.com/2016/10/08/us/politics...,In lucrative paid speeches that Hillary Clint...
32,2016-10-08,https://www.nytimes.com/2016/10/07/technology/...,"SAN FRANCISCO — Marc Benioff, the founder and..."
31,2016-10-08,https://www.nytimes.com/2016/10/05/business/pr...,Prepaid debit cards are a financial lifeline ...
30,2016-10-08,https://www.nytimes.com/2016/10/04/world/ameri...,RIO DE JANEIRO — It was not a banner day for ...
...,...,...,...
1518,2021-10-02,https://www.nytimes.com/2021/09/30/sports/socc...,Looking to expand its global footprint beyond...
1517,2021-10-02,https://www.nytimes.com/2021/10/01/business/cr...,Despite the popularity of mobile apps promisi...
1516,2021-10-02,https://www.nytimes.com/2021/10/02/your-money/...,Introducing your child to the real-world use ...
2736,2021-10-02,https://www.nytimes.com/2021/09/26/fashion/wat...,Like their counterparts in industries such as...


In [11]:
# Parse Date into datetime64 so it can drive the weekly Grouper below.
df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d')

In [12]:
# One DataFrame per weekly bucket (weekly frequency anchored on Monday), in
# chronological order. The group keys are discarded.
df_list = [part for _, part in df.groupby(pd.Grouper(key='Date', freq='W-MON'))]

In [13]:
df_list[0]

Unnamed: 0,Date,url,texts
2746,2016-10-08,https://www.nytimes.com/2016/10/04/opinion/the...,Donald Trump is a thug. He’s a thug who talks...
33,2016-10-08,https://www.nytimes.com/2016/10/08/us/politics...,In lucrative paid speeches that Hillary Clint...
32,2016-10-08,https://www.nytimes.com/2016/10/07/technology/...,"SAN FRANCISCO — Marc Benioff, the founder and..."
31,2016-10-08,https://www.nytimes.com/2016/10/05/business/pr...,Prepaid debit cards are a financial lifeline ...
30,2016-10-08,https://www.nytimes.com/2016/10/04/world/ameri...,RIO DE JANEIRO — It was not a banner day for ...
36,2016-10-08,https://www.nytimes.com/2016/10/07/business/de...,WASHINGTON — Nearly five years after Jon S. C...
37,2016-10-08,https://www.nytimes.com/2016/10/08/business/in...,LONDON — As Europe has grappled with the trau...
38,2016-10-08,https://www.nytimes.com/2016/10/08/world/europ...,LONDON — For those blithely inclined toward t...
34,2016-10-08,https://www.nytimes.com/2016/10/04/business/de...,The Janus Capital Group and the Henderson Gro...
2747,2016-10-08,https://www.nytimes.com/2016/10/09/world/middl...,TEHRAN — Rushing for a plane to Tehran becaus...


In [14]:
import yaml

# safe_load only builds plain Python containers (dict/list/str/...), which is
# all this config needs, and avoids constructing arbitrary Python objects.
# NOTE: `data` is read as a module-level global by get_matrix.
with open('../configs/dow_jones.yaml') as f:
    data = yaml.safe_load(f)

print(data)

{'companies': [{'wba': {'alias': ['$wba', 'wba', 'walgreen boots alliance inc', 'walgreen boots alliance', 'walgreenbootsalliance']}}, {'v': {'alias': ['$v', 'v', 'visa inc class a', 'visa']}}, {'crm': {'alias': ['$crm', 'crm', 'salesforce.com inc', 'salesforce']}}, {'cvx': {'alias': ['$cvx', 'cvx', 'chevron corp', 'chevron']}}, {'pg': {'alias': ['$pg', 'pg', 'procter & gamble', 'procter&gamble']}}, {'vz': {'alias': ['$vz', 'vz', 'verizon communications inc', 'verizon']}}, {'wmt': {'alias': ['$wmt', 'wmt', 'walmart stores inc', 'walmart stores', 'walmart']}}, {'unh': {'alias': ['$unh', 'unh', 'unitedhealth group inc', 'unitedhealth group', 'unitedhealthgroup']}}, {'trv': {'alias': ['$trv', 'trv', 'travelers companies inc', 'travelers companies', 'travelers', 'travelerscompanies']}}, {'mcd': {'alias': ['$mcd', 'mcd', 'mcdonalds corp', 'mcdonalds']}}, {'mmm': {'alias': ['$mmm', 'mmm', '3m', '3m']}}, {'nke': {'alias': ['$nke', 'nke', 'nike inc class b', 'nike']}}, {'mrk': {'alias': ['$mrk

In [15]:
def get_matrix(df, config=None):
    """Count, for every pair of companies, the articles mentioning either one.

    Entry ``[i, j]`` is the number of rows in ``df.texts`` that contain at
    least one alias of company ``i`` or of company ``j``; the diagonal is the
    number of rows mentioning company ``i`` itself. The matrix is symmetric.

    Args:
        df: DataFrame with a ``texts`` column of article strings.
        config: mapping with a ``'companies'`` list of single-key dicts
            ``{ticker: {'alias': [str, ...]}}``. Defaults to the module-level
            ``data`` loaded from the YAML config.

    Returns:
        ``(n_companies, n_companies)`` int64 NumPy array, in config order.
    """
    import re  # local import: only needed here to escape alias text

    if config is None:
        config = data

    tickers = []
    patterns = []
    for entry in config['companies']:
        ticker, spec = next(iter(entry.items()))
        tickers.append(ticker)
        # Aliases are literal text ('$wba', 'salesforce.com inc', ...), so
        # escape them: unescaped, '$' and '.' act as regex metacharacters.
        patterns.append("|".join(re.escape(alias) for alias in spec["alias"]))

    # One regex scan per company instead of one per company pair.
    # NOTE(review): matching is case-sensitive substring search and all the
    # aliases visible in the config are lower-case -- confirm that is intended.
    texts = df.texts
    masks = np.array(
        [texts.str.contains(pattern, regex=True, na=False).to_numpy(dtype=bool)
         for pattern in patterns],
        dtype=bool,
    ).reshape(len(tickers), len(texts))

    # "alias of i OR alias of j" is the element-wise OR of the two masks.
    # NOTE(review): this counts articles mentioning either company, not both;
    # confirm OR (rather than co-occurrence) is the intended edge weight.
    n = len(tickers)
    counts = np.zeros((n, n), dtype=np.int64)
    for i in range(n):
        for j in range(n):
            counts[i, j] = (masks[i] | masks[j]).sum()
    return counts

In [16]:
# One company-by-company count matrix per weekly news bucket.
mat_list = [get_matrix(df) for df in df_list]

In [17]:
len(mat_list)

261

In [18]:
# Persist each weekly matrix as week_<index>.npy. Iterating mat_list directly
# keeps the file count in step with the data instead of a hard-coded length.
for week, weekly_matrix in enumerate(mat_list):
    np.save('../data/raw/news/week_'+str(week)+'.npy', weekly_matrix)
        

In [19]:
import os

# Daily closing prices per ticker; only the three columns used downstream
# are read, and Date is parsed to datetime on load.
stock_csv_path = os.path.join('../data/raw', "stock", "raw.csv")
stock_df = pd.read_csv(
    stock_csv_path,
    usecols=["ticker_symbol", "Date", "Close"],
    parse_dates=["Date"],
)

In [20]:
stock_df

Unnamed: 0,Date,Close,ticker_symbol
0,2016-10-03,28.129999,aapl
1,2016-10-04,28.250000,aapl
2,2016-10-05,28.262501,aapl
3,2016-10-06,28.472500,aapl
4,2016-10-07,28.514999,aapl
...,...,...,...
36506,2021-09-27,142.250000,wmt
36507,2021-09-28,140.500000,wmt
36508,2021-09-29,140.440002,wmt
36509,2021-09-30,139.380005,wmt


In [25]:
from datetime import timedelta
from datetime import datetime
# First day of the first weekly window used for the stock features.
start = datetime.strptime('2016-10-02', "%Y-%m-%d")

In [26]:
# Weekly anchors: `start` followed by 261 further dates, each seven days
# after the previous one (262 entries in total).
date_list = [start] + [start + timedelta(days=7 * k) for k in range(1, 262)]
cur = date_list[-1]

In [27]:
date_list

[datetime.datetime(2016, 10, 2, 0, 0),
 datetime.datetime(2016, 10, 9, 0, 0),
 datetime.datetime(2016, 10, 16, 0, 0),
 datetime.datetime(2016, 10, 23, 0, 0),
 datetime.datetime(2016, 10, 30, 0, 0),
 datetime.datetime(2016, 11, 6, 0, 0),
 datetime.datetime(2016, 11, 13, 0, 0),
 datetime.datetime(2016, 11, 20, 0, 0),
 datetime.datetime(2016, 11, 27, 0, 0),
 datetime.datetime(2016, 12, 4, 0, 0),
 datetime.datetime(2016, 12, 11, 0, 0),
 datetime.datetime(2016, 12, 18, 0, 0),
 datetime.datetime(2016, 12, 25, 0, 0),
 datetime.datetime(2017, 1, 1, 0, 0),
 datetime.datetime(2017, 1, 8, 0, 0),
 datetime.datetime(2017, 1, 15, 0, 0),
 datetime.datetime(2017, 1, 22, 0, 0),
 datetime.datetime(2017, 1, 29, 0, 0),
 datetime.datetime(2017, 2, 5, 0, 0),
 datetime.datetime(2017, 2, 12, 0, 0),
 datetime.datetime(2017, 2, 19, 0, 0),
 datetime.datetime(2017, 2, 26, 0, 0),
 datetime.datetime(2017, 3, 5, 0, 0),
 datetime.datetime(2017, 3, 12, 0, 0),
 datetime.datetime(2017, 3, 19, 0, 0),
 datetime.datetime(2

In [29]:
# Stack every .npy array found in the SEC directory, in sorted filename
# order, into a single tensor.
npy_names = [
    name
    for name in sorted(os.listdir("../data/raw/sec/"))
    if name.split(".")[-1]=='npy'
]
comp_emb = torch.stack(
    [
        torch.from_numpy(np.load(os.path.join("../data/raw", "sec", name)))
        for name in npy_names
    ]
)

In [30]:
def _window_mask(first_day, last_day):
    """Boolean mask of stock_df rows dated within [first_day, last_day]."""
    return (stock_df.Date>=first_day) & (stock_df.Date<=last_day)


def _ticker_by_day(frame):
    """Pivot closes to one row per ticker_symbol and one column per Date."""
    return frame.pivot_table(
        index="Date", columns="ticker_symbol", values="Close"
    ).values.T


# For each of the first 260 weekly anchors, pair this week's closes (X) with
# the relative change of next week's mean close against this week's mean (y).
# NOTE(review): rows follow pivot_table's ticker_symbol column order; confirm
# it matches the company order used for the news adjacency matrices.
X_y = []
for i in range(260):
    start_date = date_list[i]
    end_date = start_date + timedelta(days=6)
    next_start_date = start_date + timedelta(days=7)
    next_end_date = start_date + timedelta(days=13)

    # Features: closes inside the current 7-day window.
    curr = stock_df[_window_mask(start_date, end_date)]
    X = _ticker_by_day(curr)

    # Label source: closes inside the following 7-day window.
    nxt = stock_df[_window_mask(next_start_date, next_end_date)]
    this_week_mean = X.mean(1)
    y = (_ticker_by_day(nxt).mean(1) - this_week_mean) / this_week_mean

    X_y.append((torch.tensor(X), torch.tensor(y)))

In [33]:
# Directory prefix for the quarterly Twitter graph files; the file name is
# appended by plain string concatenation, so the trailing slash matters.
path = "../data/processed/twitter/"

In [34]:
# Quarter labels from 2016 Q4 through 2021 Q3 (2021 Q4 is excluded).
quarter = ['2016_q4'] + [
    str(year) + '_q' + str(q)
    for year in range(2017, 2022)
    for q in range(1, 5)
    if not (year == 2021 and q == 4)
]

In [35]:
quarter

['2016_q4',
 '2017_q1',
 '2017_q2',
 '2017_q3',
 '2017_q4',
 '2018_q1',
 '2018_q2',
 '2018_q3',
 '2018_q4',
 '2019_q1',
 '2019_q2',
 '2019_q3',
 '2019_q4',
 '2020_q1',
 '2020_q2',
 '2020_q3',
 '2020_q4',
 '2021_q1',
 '2021_q2',
 '2021_q3']

In [36]:
# Full file path of each quarterly graph: prefix + quarter label + '.pt'.
paths = [path + label + '.pt' for label in quarter]

In [37]:
paths

['../data/processed/twitter/2016_q4.pt',
 '../data/processed/twitter/2017_q1.pt',
 '../data/processed/twitter/2017_q2.pt',
 '../data/processed/twitter/2017_q3.pt',
 '../data/processed/twitter/2017_q4.pt',
 '../data/processed/twitter/2018_q1.pt',
 '../data/processed/twitter/2018_q2.pt',
 '../data/processed/twitter/2018_q3.pt',
 '../data/processed/twitter/2018_q4.pt',
 '../data/processed/twitter/2019_q1.pt',
 '../data/processed/twitter/2019_q2.pt',
 '../data/processed/twitter/2019_q3.pt',
 '../data/processed/twitter/2019_q4.pt',
 '../data/processed/twitter/2020_q1.pt',
 '../data/processed/twitter/2020_q2.pt',
 '../data/processed/twitter/2020_q3.pt',
 '../data/processed/twitter/2020_q4.pt',
 '../data/processed/twitter/2021_q1.pt',
 '../data/processed/twitter/2021_q2.pt',
 '../data/processed/twitter/2021_q3.pt']

In [38]:
# Accumulator for the transformed quarterly graphs.
data_list = []

In [39]:
# Load and transform every quarterly graph. The list is rebuilt in a single
# expression so re-running the cell cannot append duplicates, and the loop
# uses its own name so the module-level `path` (directory prefix) and `data`
# (YAML company config read by get_matrix) are not overwritten.
data_list = [transform_and_split(torch.load(pt_file)[0]) for pt_file in paths]


Data(x=[29, 63], edge_index=[2, 760], edge_attr=[760], y=[29])
Number of nodes: 29
Number of edges: 760
Average node degree: 26.21
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 62], edge_index=[2, 312], edge_attr=[312], y=[29])
Number of nodes: 29
Number of edges: 312
Average node degree: 10.76
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 400], edge_attr=[400], y=[29])
Number of nodes: 29
Number of edges: 400
Average node degree: 13.79
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 552], edge_attr=[552], y=[29])
Number of nodes: 29
Number of edges: 552
Average node degree: 19.03
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 805], edge_attr=[805], y=[29])
Number of nodes: 29
Number of edges: 805
Average node degree: 27.76
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data

In [40]:
len(data_list)

20

In [41]:
data_list[1].x.shape

torch.Size([29, 62])

In [42]:
# NumPy views of the quarterly Twitter graphs, one list entry per quarter.
# NOTE: all four names are reassigned further down from the news matrices
# and stock data, so these values are only useful for inspection.
edge_indices = [graph.edge_index.numpy() for graph in data_list]
edge_weights = [graph.edge_attr.numpy() for graph in data_list]
features = [graph.x.numpy() for graph in data_list]
targets = [graph.y.numpy() for graph in data_list]

In [43]:
edge_indices

[array([[ 0,  0,  0, ..., 28, 28, 28],
        [ 2,  3,  4, ..., 26, 27, 28]], dtype=int64),
 array([[ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,
          2,  2,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,
          6,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12,
         12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15,
         15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17,
         17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 20, 20, 20,
         20, 20, 20, 21, 21, 21

In [44]:
from torch_geometric.utils import dense_to_sparse

# Convert the first 260 weekly count matrices to (edge_index, edge_attr)
# pairs, then split the pairs into two parallel lists.
sparse_pairs = [
    dense_to_sparse(torch.from_numpy(mat_list[week])) for week in range(260)
]
edge_idx = [pair[0] for pair in sparse_pairs]
edge_att = [pair[1] for pair in sparse_pairs]

In [45]:
# Replace the earlier Twitter-based edge lists with the news-based ones,
# converted to NumPy arrays (one per week).
edge_indices = [i.numpy() for i in edge_idx]
edge_weights = [i.numpy() for i in edge_att]

In [46]:
# Weekly node features: each ticker's closes scaled with sklearn's
# normalize(norm='max') along the row.
features = [
    normalize(X_y[week][0].numpy(), axis=1, norm='max') for week in range(260)
]

# Weekly binary targets: 1 where the relative change is strictly positive.
targets = np.asarray(
    [(X_y[week][1].numpy() > 0).astype(int) for week in range(260)]
)

In [47]:
targets[3]

array([0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0])

In [48]:
mat_list[2].shape

(29, 29)

In [50]:
# Right-pad every weekly matrix to 5 columns; np.pad's 'mean' mode fills the
# added columns with each row's mean. No rows are added.
padded_features = [
    np.pad(week, [(0, 0), (0, 5 - week.shape[1])], 'mean') for week in features
]

In [51]:
# Stack the equally-shaped weekly matrices into one 3-D array.
padded_features = np.asarray(padded_features)

In [52]:
padded_features.shape

(260, 29, 5)

In [53]:
# Repeat the static per-company embedding once per weekly snapshot.
# The tensor check keeps the cell idempotent: after the first run comp_emb is
# already the tiled ndarray and must not be converted or tiled again.
if torch.is_tensor(comp_emb):
    comp_emb = np.asarray([comp_emb.numpy() for _ in range(len(padded_features))])

In [54]:
comp_emb.shape

(260, 29, 768)

In [55]:
# Append the company embeddings to the 5 padded price columns. Slicing the
# first 5 columns makes a re-run rebuild the same array instead of stacking
# the embeddings on a second time.
padded_features = np.concatenate((padded_features[:, :, :5], comp_emb), axis = 2)

In [56]:
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal

In [57]:
len(padded_features)

260

In [58]:
# Bundle the per-week edges, edge weights, node features and binary targets
# into one dynamic-graph temporal signal (one snapshot per week).
temporal_signal = DynamicGraphTemporalSignal(edge_indices = edge_indices , edge_weights = edge_weights, features = padded_features, targets = targets)

In [59]:
temporal_signal

<torch_geometric_temporal.signal.dynamic_graph_temporal_signal.DynamicGraphTemporalSignal at 0x22de69f7f10>

In [60]:
from torch_geometric_temporal.signal import temporal_signal_split

# Split the signal with train_ratio=0.8 into a train part and a test part.
train_dataset, test_dataset = temporal_signal_split(temporal_signal, train_ratio=0.8)

In [61]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal import TemporalConv
from torch_geometric_temporal import EvolveGCNO
from torch_geometric_temporal import GConvGRU
class RecurrentGCN(torch.nn.Module):
    """GConvGRU layer followed by ReLU and a two-unit linear head.

    Only the layers that ``forward`` actually uses are created. The previous
    EvolveGCNO, DCRNN and Dropout members were never called and only added
    untrained parameters.

    Args:
        node_features: number of input features per node.
    """

    def __init__(self, node_features):
        super().__init__()
        self.conv = GConvGRU(node_features, 64, 3)
        self.linear = torch.nn.Linear(64, 2)

    def forward(self, x, edge_index, edge_weight):
        """Return a ``(num_nodes, 2)`` tensor of sigmoid-squashed class scores.

        Args:
            x: node feature matrix for one snapshot.
            edge_index: edge index tensor for the snapshot.
            edge_weight: per-edge weights for the snapshot.
        """
        # No hidden state is passed to the GRU cell, so nothing is carried
        # over between snapshots here -- TODO confirm that is intended.
        h = self.conv(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        # NOTE(review): the training cell feeds this output to
        # CrossEntropyLoss, which expects unnormalised logits. The sigmoid is
        # kept because downstream cells threshold these values directly.
        return torch.sigmoid(h)

In [62]:
from tqdm import tqdm

# 773 input features per node = 5 padded price columns + 768 embedding dims.
model = RecurrentGCN(node_features = 773)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# The loss module is stateless, so build it once instead of per snapshot.
criterion = torch.nn.CrossEntropyLoss()

model.train()

for epoch in tqdm(range(200)):
    # Sum the loss over every training snapshot, then take a single
    # optimizer step per epoch.
    loss = 0
    for snapshot in train_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        loss += criterion(y_hat, snapshot.y.long())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

100%|██████████| 200/200 [28:51<00:00,  8.66s/it]


In [63]:
# Collect the model output for every test snapshot. Inference runs under
# torch.no_grad() so no autograd graph is built or kept for the predictions.
y_hat_l = []
model.eval()
with torch.no_grad():
    for snapshot in test_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        y_hat_l.append(y_hat)


In [64]:
y_hat_l

[tensor([[0.2463, 0.7511],
         [0.6364, 0.3697],
         [0.1962, 0.8088],
         [0.2867, 0.7180],
         [0.3509, 0.6552],
         [0.0968, 0.9059],
         [0.2954, 0.6922],
         [0.4180, 0.5805],
         [0.3080, 0.6721],
         [0.3419, 0.6665],
         [0.3282, 0.6772],
         [0.3104, 0.6879],
         [0.3382, 0.6655],
         [0.3535, 0.6595],
         [0.4030, 0.6023],
         [0.2265, 0.7795],
         [0.3195, 0.6725],
         [0.2009, 0.8185],
         [0.4203, 0.5928],
         [0.2005, 0.8159],
         [0.1237, 0.8772],
         [0.4547, 0.5522],
         [0.5536, 0.4180],
         [0.8471, 0.1337],
         [0.0891, 0.9063],
         [0.1991, 0.8201],
         [0.7791, 0.2162],
         [0.3796, 0.6143],
         [0.1670, 0.8344]], grad_fn=<SigmoidBackward>),
 tensor([[0.1557, 0.8454],
         [0.5533, 0.4568],
         [0.0758, 0.9287],
         [0.1706, 0.8340],
         [0.1782, 0.8324],
         [0.0193, 0.9826],
         [0.0698, 0.9306],

In [65]:
# Flatten the per-snapshot outputs into one list holding one score row per
# node, in snapshot order.
# NOTE: y_hat_l is rebound, so this cell must run exactly once after the
# evaluation cell.
y_hat_l = [
    row
    for output in y_hat_l
    for row in np.squeeze(output.detach().numpy())
]

In [66]:
# Keep only the second score of each row (index 1, used as the class-1 score).
# y_hat_l is rebound again here, so do not re-run this cell on its own.
y_hat_l = [y[1] for y in y_hat_l]

In [67]:
y_hat_l

[0.7510622,
 0.36967477,
 0.80878544,
 0.71799827,
 0.6551511,
 0.905933,
 0.69224125,
 0.5805447,
 0.6720963,
 0.6664721,
 0.6772482,
 0.6879442,
 0.6654967,
 0.6595472,
 0.6022927,
 0.7794788,
 0.6725395,
 0.81850195,
 0.5928065,
 0.81586474,
 0.8772021,
 0.5521879,
 0.41796514,
 0.1336909,
 0.9063242,
 0.82005376,
 0.21621192,
 0.61428064,
 0.8344264,
 0.8454016,
 0.45681918,
 0.92868906,
 0.83400255,
 0.83235365,
 0.9825515,
 0.9306407,
 0.6481995,
 0.7520678,
 0.83850485,
 0.8249852,
 0.89410543,
 0.8994224,
 0.8210432,
 0.8170324,
 0.95012295,
 0.882835,
 0.9221026,
 0.83387864,
 0.15474838,
 0.97092664,
 0.61608803,
 0.48040196,
 0.5872975,
 0.9507781,
 0.877131,
 0.32451764,
 0.79722583,
 0.9030667,
 0.6679942,
 0.59206545,
 0.5807945,
 0.52944535,
 0.3348449,
 0.55904347,
 0.5457959,
 0.4610916,
 0.44583404,
 0.41582236,
 0.5098483,
 0.62035257,
 0.42747405,
 0.5328888,
 0.44028848,
 0.39611262,
 0.49759465,
 0.7437328,
 0.50086534,
 0.56331646,
 0.65032375,
 0.71813536,
 0.66

In [74]:
import numpy as np
from sklearn import metrics

# Build the ground-truth labels directly from the test snapshots. This cell
# sits above the cells that define `true_label`, so depending on that name
# would fail on a fresh top-to-bottom run.
y = [
    int(label)
    for snapshot in test_dataset
    for label in snapshot.y.detach().numpy()
]
pred = np.array(y_hat_l)
fpr, tpr, thresholds = metrics.roc_curve(y, pred)
metrics.auc(fpr, tpr)

0.6017447000805548

In [72]:
# Hard labels: 1 when the class-1 score exceeds the hand-picked 0.69 cut-off.
y_hat_list = [int(score > 0.69) for score in y_hat_l]

In [69]:
# One list of per-node labels for each test snapshot, in iteration order.
true_label = [list(snap.y.detach().numpy()) for snap in test_dataset]

In [70]:
# Flatten the per-snapshot label lists into one list of Python ints.
# true_label is rebound, so this cell must run exactly once after the
# previous one.
true_label = [int(z) for y in true_label for z in y]

In [73]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 of the thresholded predictions against
# the flattened test labels.
y_true = true_label
target_names = ['class 0', 'class 1']
print(classification_report(y_true, y_hat_list, target_names=target_names))

              precision    recall  f1-score   support

     class 0       0.50      0.60      0.55       661
     class 1       0.63      0.54      0.58       847

    accuracy                           0.56      1508
   macro avg       0.57      0.57      0.56      1508
weighted avg       0.58      0.56      0.57      1508

