In [648]:
from torch_geometric.data import InMemoryDataset, Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Sequential, GINConv, GATConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import to_networkx, from_networkx, add_self_loops

from torch.nn import Linear, BatchNorm1d, ReLU, Dropout, LeakyReLU, SiLU, PReLU
import torch.nn.functional as F
import torch
from torch.utils.tensorboard import  SummaryWriter

import numpy as np
import networkx as nx
import pandas as pd
import pickle
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix


from nba_api.stats.static import teams,players

import traceback
import difflib
import os

In [649]:
try:
    df_nba_teams = pd.read_csv('nba_teams_id.csv')
except:
    # Get all the NBA teams description
    nba_teams = teams.get_teams()
    df_nba_teams = pd.DataFrame(nba_teams)
    df_nba_teams.to_csv('nba_teams_id.csv',index=False)

df_nba_teams.head()

Unnamed: 0,id,full_name,abbreviation,nickname,city,state,year_founded
0,1610612737,Atlanta Hawks,ATL,Hawks,Atlanta,Georgia,1949
1,1610612738,Boston Celtics,BOS,Celtics,Boston,Massachusetts,1946
2,1610612739,Cleveland Cavaliers,CLE,Cavaliers,Cleveland,Ohio,1970
3,1610612740,New Orleans Pelicans,NOP,Pelicans,New Orleans,Louisiana,2002
4,1610612741,Chicago Bulls,CHI,Bulls,Chicago,Illinois,1966


In [650]:
team_id = 1610612741
df_nba_teams.query(f"id == {team_id}")['full_name'].values[0]

'Chicago Bulls'

In [651]:
# Get all the NBA players description
try:
    df_nba_players = pd.read_csv('nba_players.csv')
except:
    df_nba_players = pd.DataFrame(players.get_players())
    df_nba_players.to_csv('nba_players.csv',index=False)

df_nba_players.head()

Unnamed: 0,id,full_name,first_name,last_name,is_active
0,76001,Alaa Abdelnaby,Alaa,Abdelnaby,False
1,76002,Zaid Abdul-Aziz,Zaid,Abdul-Aziz,False
2,76003,Kareem Abdul-Jabbar,Kareem,Abdul-Jabbar,False
3,51,Mahmoud Abdul-Rauf,Mahmoud,Abdul-Rauf,False
4,1505,Tariq Abdul-Wahad,Tariq,Abdul-Wahad,False


In [652]:
df_ratings_NBA2k_2020_2023 = pd.read_csv('ratings_NBA2k_2020_2023.csv')
df_ratings_NBA2k_2020_2023.head()

Unnamed: 0.1,Unnamed: 0,Season_year,Team,Player,Shot Close,Shot Mid,Shot 3pt,Shot IQ,Free Throw,Offensive Consistency,...,Help Defense IQ,Lateral Quickness,Pass Perception,Steal,Block,Defensive Consistency,Offensive Rebound,Defensive Rebound,Intangibles,Potential
0,0,2020,Atlanta Hawks,Trae Young,83,77,80,90,81,96,...,56,68,68,39,35,28,36,48,98,89
1,1,2020,Atlanta Hawks,John Collins,69,67,72,90,70,89,...,68,51,32,26,51,54,80,75,95,86
2,2,2020,Atlanta Hawks,Jabari Parker,81,70,69,90,66,75,...,64,62,63,36,50,50,44,75,95,78
3,3,2020,Atlanta Hawks,Kevin Huerter,72,70,80,90,71,63,...,61,70,62,41,49,33,37,47,95,82
4,4,2020,Atlanta Hawks,De'Andre Hunter,62,71,74,90,70,42,...,72,80,52,32,46,72,47,51,95,88


In [653]:
df_players_formation = pd.read_csv('players_formation.csv')
df_players_formation.head()

Unnamed: 0.1,Unnamed: 0,GAME_ID,TEAM_ID,PLAYER_ID,PLAYER_NAME,START_POSITION
0,0,22001066,1610612745,1630256,Jae'Sean Tate,F
1,1,22001066,1610612745,1630231,Kenyon Martin Jr.,F
2,2,22001066,1610612745,203482,Kelly Olynyk,C
3,3,22001066,1610612745,1630237,Anthony Lamb,G
4,4,22001066,1610612745,201571,D.J. Augustin,G


In [654]:
with open('season_campaign_all_teams_dict.pickle','rb') as results:
    season_campaign_all_teams_dict = pickle.load(results)

season_campaign_all_teams_dict['2020-21'][1610612766][1610612766]

Unnamed: 0,SEASON_ID,TEAM_ID,TEAM_ABBREVIATION,TEAM_NAME,GAME_ID,GAME_DATE,MATCHUP,WL,MIN,PTS,...,DREB,REB,AST,STL,BLK,TOV,PF,PLUS_MINUS,TEAM_ADVERSARY_ABBREVIATION,WINNER
0,22020,1610612766,CHA,Charlotte Hornets,0022001080,2021-05-16,CHA @ WAS,L,240,110,...,32,42,25,8,5,11,15,-5.0,WAS,False
1,22020,1610612766,CHA,Charlotte Hornets,0022001064,2021-05-15,CHA @ NYK,L,265,109,...,36,51,28,5,2,8,18,-9.0,NYK,False
2,22020,1610612766,CHA,Charlotte Hornets,0022001047,2021-05-13,CHA vs. LAC,L,239,90,...,27,35,24,9,1,10,15,-23.0,LAC,False
3,22020,1610612766,CHA,Charlotte Hornets,0022000448,2021-05-11,CHA vs. DEN,L,240,112,...,34,49,30,9,5,17,17,-5.0,DEN,False
4,22020,1610612766,CHA,Charlotte Hornets,0022001020,2021-05-09,CHA vs. NOP,L,240,110,...,35,46,22,6,8,17,19,-2.0,NOP,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,22020,1610612766,CHA,Charlotte Hornets,0022000069,2021-01-01,CHA vs. MEM,L,241,93,...,32,42,23,9,4,18,15,-15.0,MEM,False
68,22020,1610612766,CHA,Charlotte Hornets,0022000059,2020-12-30,CHA @ DAL,W,240,118,...,38,50,30,6,6,23,28,19.0,DAL,True
69,22020,1610612766,CHA,Charlotte Hornets,0022000032,2020-12-27,CHA vs. BKN,W,241,106,...,39,52,35,8,3,12,25,2.0,BKN,True
70,22020,1610612766,CHA,Charlotte Hornets,0022000022,2020-12-26,CHA vs. OKC,L,239,107,...,34,47,27,11,8,13,22,-2.0,OKC,False


In [655]:
count=0
for y,d1 in season_campaign_all_teams_dict.items():
    for k1,d2 in d1.items():
        for k2,d3 in d2.items():
            count+=1
            # print(y,k2)
count


90

In [656]:
df_ratings_NBA2k_2020_2023.head()

Unnamed: 0.1,Unnamed: 0,Season_year,Team,Player,Shot Close,Shot Mid,Shot 3pt,Shot IQ,Free Throw,Offensive Consistency,...,Help Defense IQ,Lateral Quickness,Pass Perception,Steal,Block,Defensive Consistency,Offensive Rebound,Defensive Rebound,Intangibles,Potential
0,0,2020,Atlanta Hawks,Trae Young,83,77,80,90,81,96,...,56,68,68,39,35,28,36,48,98,89
1,1,2020,Atlanta Hawks,John Collins,69,67,72,90,70,89,...,68,51,32,26,51,54,80,75,95,86
2,2,2020,Atlanta Hawks,Jabari Parker,81,70,69,90,66,75,...,64,62,63,36,50,50,44,75,95,78
3,3,2020,Atlanta Hawks,Kevin Huerter,72,70,80,90,71,63,...,61,70,62,41,49,33,37,47,95,82
4,4,2020,Atlanta Hawks,De'Andre Hunter,62,71,74,90,70,42,...,72,80,52,32,46,72,47,51,95,88


In [657]:
def get_team_name(team_id):
    return df_nba_teams.query(f"id == {team_id}")['full_name'].values[0]

def create_fully_connected_graph_new(G_completed_team, game_id, team_id, show_graph, season_year, team_name, df_ratings_NBA2k_2020_2023_PLAYER_ID, team_side:int):

    df_rost = df_players_formation.query(f"GAME_ID == {int(game_id)} & TEAM_ID == {team_id}")[['PLAYER_NAME','START_POSITION']]

    reserve_players = df_rost['PLAYER_NAME'][df_rost['START_POSITION'].isna().to_list()].values
    starting_lineup_players = df_rost['PLAYER_NAME'][np.invert(df_rost['START_POSITION'].isna().values)].values

    rating_embeddings_new = []
    for player_name in reserve_players:
        embedding = rating_embeddings(df_ratings_NBA2k_2020_2023_PLAYER_ID,season_year,team_name,player_name)

        rating_embeddings_new.append((player_name,{'x':embedding,'starting_lineup':False,'team_side':team_side}))

    G_completed_team.add_nodes_from(rating_embeddings_new)


    rating_embeddings_new = []
    for player_name in starting_lineup_players:
        embedding = rating_embeddings(df_ratings_NBA2k_2020_2023_PLAYER_ID,season_year,team_name,player_name)
        rating_embeddings_new.append((player_name,{'x':embedding,'starting_lineup':True,'team_side':team_side}))

    G_completed_team.add_nodes_from(rating_embeddings_new)


    # assert sum(nx.get_node_attributes(G_completed_team,'starting_lineup').values()) == 5, "should be 5 players is in the stating lineup"

    G_completed_team_full_connected = nx.complete_graph(G_completed_team.nodes)
    G_completed_team = nx.compose(G_completed_team, G_completed_team_full_connected)

    # nx.draw(G_completed_team, with_labels=True, font_weight='bold',node_color="red")

    G_completed_team = nx.compose(G_completed_team, G_completed_team_full_connected)

    if (show_graph):
        starting_lineup_colors = list(nx.get_node_attributes(G_completed_team,'starting_lineup').values())
        nx.draw_circular(G_completed_team, with_labels=True, font_weight='bold',node_color=['red' if x else 'blue' for x in starting_lineup_colors])
        plt.show()

    return G_completed_team, starting_lineup_players, df_rost

# def create_fully_connected_graph(G_completed_team, game_id, team_id, show_graph, season_year, team_name):

#     player_name=player_name
    
    
#     df_rost = df_players_formation.query(f"GAME_ID == {int(game_id)} & TEAM_ID == {team_id}")[['PLAYER_NAME','START_POSITION']]

#     reserve_players = df_rost['PLAYER_NAME'][df_rost['START_POSITION'].isna().to_list()].values
#     starting_lineup_players = df_rost['PLAYER_NAME'][np.invert(df_rost['START_POSITION'].isna().values)].values
#     G_completed_team.add_nodes_from(reserve_players,starting_lineup=False)
#     G_completed_team.add_nodes_from(starting_lineup_players,starting_lineup=True)


#     # assert sum(nx.get_node_attributes(G_completed_team,'starting_lineup').values()) == 5, "should be 5 players is in the stating lineup"

#     G_completed_team_full_connected = nx.complete_graph(G_completed_team.nodes)
#     G_completed_team = nx.compose(G_completed_team, G_completed_team_full_connected)

#     # nx.draw(G_completed_team, with_labels=True, font_weight='bold',node_color="red")

#     G_completed_team = nx.compose(G_completed_team, G_completed_team_full_connected)

#     if (show_graph):
#         starting_lineup_colors = list(nx.get_node_attributes(G_completed_team,'starting_lineup').values())
#         nx.draw_circular(G_completed_team, with_labels=True, font_weight='bold',node_color=['red' if x else 'blue' for x in starting_lineup_colors])
#         plt.show()

#     return G_completed_team, starting_lineup_players



# check what payers in ratings_NBA2k_2020_2023.csv doesn't have the respective PLAYER_ID of the players_formation.csv
# df_ratings_NBA2k_2020_2023_PLAYER_ID[df_ratings_NBA2k_2020_2023_PLAYER_ID['PLAYER_ID'].map(lambda x:np.isnan(x))]

df_ratings_NBA2k_2020_2023_PLAYER_ID = pd.merge(df_ratings_NBA2k_2020_2023,df_nba_players[['full_name','id']], how='left', copy=False, left_on='Player',right_on='full_name')
def rating_embeddings(df_ratings_NBA2k_2020_2023_PLAYER_ID, season_year,team_name,player_name):
    rating_columns = ['Shot Close', 'Shot Mid','Shot 3pt', 'Shot IQ', 'Free Throw', 'Offensive Consistency','Driving Layup', 'Standing Dunk', 'Driving Dunk', 'Draw Foul','Post Moves', 'Post Hook', 'Post Fade', 'Hands', 'Speed With Ball','Ball Handle', 'Passing Accuracy', 'Passing Vision', 'Passing IQ','Speed', 'Acceleration', 'Vertical', 'Strength', 'Stamina', 'Hustle','Interior Defense', 'Perimeter Defense', 'Help Defense IQ','Lateral Quickness', 'Pass Perception', 'Steal', 'Block','Defensive Consistency', 'Offensive Rebound', 'Defensive Rebound','Intangibles', 'Potential']
    
    try:
        # return torch.tensor(df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Season_year == {season_year} and Team == '{team_name}' and Player == '{player_name}'")[rating_columns].values[0])
        return df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Season_year == {season_year} and Team == '{team_name}' and Player == '{player_name}'")[rating_columns].values[0]
    except:
        pass
    try:
        # return torch.tensor(df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Team == '{team_name}' and Player == '{player_name}'")[rating_columns].values[0])
        return df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Team == '{team_name}' and Player == '{player_name}'")[rating_columns].values[0]
    except:
        pass

    # try the closest name
    try:
        player_name_closed = difflib.get_close_matches(player_name,df_ratings_NBA2k_2020_2023_PLAYER_ID['Player'],1)[0]
        # return torch.tensor(df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Player == '{player_name_closed}'")[rating_columns].values[0])
        return df_ratings_NBA2k_2020_2023_PLAYER_ID.query(f"Player == '{player_name_closed}'")[rating_columns].values[0]
    except:
        pass

    # return zeros if the player cannot be found
    return np.zeros([len(rating_columns)], dtype=np.int64)

def train(model, data_train_loader,criterion, optimizer, device='cpu'):
    try:
        for data_train in data_train_loader:
            # data_train = data_train.to(device)
            optimizer.zero_grad()  # Clear gradients.
            out = model(
                x=data_train.x,
                edge_index=data_train.edge_index,
                batch=data_train.batch,
                team_side=data_train.team_side,
                device=device
            )
        
            loss = criterion(out, data_train.y)
            # print(loss)
            loss.backward()  # Derive gradients.
            optimizer.step()  # Update parameters based on gradients.
    except:
        traceback.print_exc()
        
def test(model, data_train_loader,criterion, device='cpu'):
    pred_y_total = torch.tensor([], device=device)
    actual_y_total = torch.tensor([], device=device)
    valid_losses = []
    for data_train in data_train_loader:
        # data_train = data_train.to(device)
        out = model(
            x=data_train.x,
            edge_index=data_train.edge_index,
            batch=data_train.batch,
            team_side=data_train.team_side,
            device=device
        )


        pred_y = out.argmax(dim=1)  # Use the class with highest probability.
        pred_y_total = torch.cat([pred_y_total,pred_y])        
        actual_y = data_train.y.squeeze()       
        actual_y_total = torch.cat([actual_y_total,actual_y])


        loss = criterion(out, data_train.y)
        valid_losses.append(loss.item()) # record validation loss

    valid_losses = np.average(valid_losses)
    
    actual_y_total_cpu = actual_y_total.cpu()
    pred_y_total_cpu = pred_y_total.cpu()
    acc = accuracy_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu)
    precision = precision_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average=None, zero_division=0)
    precision_avg = precision_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average='weighted', zero_division=0)
    recall = recall_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average=None,zero_division=0)        
    recall_avg = recall_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average='weighted',zero_division=0)
    f1_score_result = f1_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average=None, zero_division=0)
    f1_score_result_avg = f1_score(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu, average='weighted', zero_division=0)
    c_matrix = confusion_matrix(y_true=actual_y_total_cpu, y_pred=pred_y_total_cpu)

    return acc, valid_losses, precision, precision_avg, recall, recall_avg, f1_score_result, f1_score_result_avg, c_matrix, actual_y_total_cpu, pred_y_total_cpu
    


In [658]:
class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, dataset_name, season_list, df_nba_teams, season_campaign_all_teams_dict, transform=None, pre_transform=None, pre_filter=None):
        # Initialize the dataset with some arguments
        self.dataset_name = dataset_name
        self.season_list = season_list
        self.df_nba_teams = df_nba_teams
        self.season_campaign_all_teams_dict = season_campaign_all_teams_dict
        super(MyOwnDataset, self).__init__(root, transform, pre_transform,pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        # Return a list of the names of the raw data files
        # return os.listdir(self.raw_dir)     
        return ['file_gnn']  

    @property
    def processed_file_names(self):
        # Return a list of the names of the processed data files
        return [f'{self.dataset_name}_{i}.pt' for i in range(len(self.raw_file_names))]

    def process(self):
        # Process your raw data files and convert them to Data objects
        # Save the processed data to disk

        verbose = False
        G_completed_team_1_2_list = []
        for season in self.season_list:
            print(season)
            for team_1_id, team_1_name in zip(self.df_nba_teams['id'],self.df_nba_teams['full_name']):        
                
                df_season_campaign_all_teams_dict = self.season_campaign_all_teams_dict[season][team_1_id][team_1_id]

                for team_2,game_id,winner in zip(df_season_campaign_all_teams_dict['TEAM_ADVERSARY_ABBREVIATION'],df_season_campaign_all_teams_dict['GAME_ID'],df_season_campaign_all_teams_dict['WINNER']):
                    team_2_id = self.df_nba_teams.query(f"abbreviation == '{team_2}'")['id'].values[0]

                    season_year=int(season[:4])

                    G_completed_team_1 = nx.Graph()
                    G_completed_team_1, starting_lineup_players_team_1, df_rost_new = create_fully_connected_graph_new(G_completed_team=G_completed_team_1, game_id=game_id, team_id=team_1_id, show_graph=verbose, season_year=season_year,team_name=team_1_name,df_ratings_NBA2k_2020_2023_PLAYER_ID=df_ratings_NBA2k_2020_2023_PLAYER_ID,team_side=0)

                    G_completed_team_2 = nx.Graph()
                    team_2_name = get_team_name(team_2_id)
                    G_completed_team_2, starting_lineup_players_team_2, df_rost_new = create_fully_connected_graph_new(G_completed_team=G_completed_team_2, game_id=game_id, team_id=team_2_id, show_graph=verbose, season_year=season_year,team_name=team_2_name,df_ratings_NBA2k_2020_2023_PLAYER_ID=df_ratings_NBA2k_2020_2023_PLAYER_ID,team_side=1)
                    
                    G_completed_team_1_2 = nx.compose(G_completed_team_1,G_completed_team_2)

                    if (verbose):
                        starting_lineup_colors = list(nx.get_node_attributes(G_completed_team_1_2,'starting_lineup').values())                    
                        nx.draw_circular(G_completed_team_1_2, with_labels=True, font_weight='bold',node_color=['red' if x else 'blue' for x in starting_lineup_colors])
                        plt.show()

                    starting_lineup_players_fully_connected = G_completed_team_1_2.subgraph(np.concatenate([starting_lineup_players_team_1,starting_lineup_players_team_2]))
                    starting_lineup_players_fully_connected = nx.complete_graph(starting_lineup_players_fully_connected)

                    G_completed_team_1_2 = nx.compose(G_completed_team_1_2, starting_lineup_players_fully_connected)
                    if (verbose):
                        starting_lineup_colors = list(nx.get_node_attributes(G_completed_team_1_2,'starting_lineup').values())
                        nx.draw_circular(G_completed_team_1_2, with_labels=True, font_weight='bold',node_color=['red' if x else 'blue' for x in starting_lineup_colors])
                        plt.show()

                    graph_instance = from_networkx(G_completed_team_1_2)                    
                    graph_instance.y = int(winner)
                    # update the list of instances
                    G_completed_team_1_2_list.append(graph_instance)

                    # print(f"Season: {season_year} | {team_1_id} : {team_1_name} | {team_2_id} : {team_2_name}")        

        # if self.pre_filter is not None:
        #     self.data = [data for data in self.data if self.pre_filter(data)]

        # if self.pre_transform is not None:
        #     self.data = [self.pre_transform(data) for data in self.data]

        data, slices = self.collate(G_completed_team_1_2_list)

        torch.save((data, slices), self.processed_paths[0])

    # def len(self):
    #     # Return the number of graphs in the dataset
    #     return len(self.processed_file_names)

    # def get(self, idx):
    #     # Return the Data object at index idx
    #     data = torch.load(os.path.join(self.processed_dir, f'{self.dataset_name}_{idx}.pt'))
    #     return data


In [659]:
class GNN_model(torch.nn.Module):
    def __init__(self,in_channels,out_channels,gfunc, dropout,num_classes):
        super(GNN_model, self).__init__()
        self.in_channels = in_channels
        self.gfunc = gfunc
        self.n_global_functions = len(self.gfunc)
        self.lin_global = Linear(in_channels * self.n_global_functions * 2, num_classes, bias=True)
        self.dropout = dropout
        self.gnn_layer = Sequential('x, edge_index',[
            (GINConv(
                torch.nn.Sequential(
                    Linear(in_channels, in_channels),
                    BatchNorm1d(in_channels),
                    PReLU(in_channels),
                    Linear(in_channels, in_channels),
                    PReLU(in_channels),
                )
            ),'x, edge_index -> h_hat'),
            (BatchNorm1d(in_channels),'h_hat -> h_hat'),
            (Linear(in_channels, in_channels, bias=True),'h_hat -> h_hat'),
            PReLU(in_channels),
        ])
            # (GATConv(in_channels,out_channels),'h_hat -> h_hat'),
            # (GATConv(in_channels,out_channels,dropout=self.dropout),'h_hat -> h_hat'),
        # self.norm = BatchNorm1d(in_channels)
            # (BatchNorm1d(in_channels),'h_hat -> h_hat'),
        
        self.prelu = PReLU(out_channels)

        self.gat = GATConv(in_channels,in_channels,dropout=self.dropout)
    
    def readout_vector(self,h,batch,gfunc,device='cpu'):

        global_vector = torch.tensor([]).to(device)

        for g in gfunc:
            global_vector = torch.cat([global_vector,g(h,batch).to(device)], dim=1).to(device)
        return global_vector
        
    def forward(self, x, edge_index, batch, team_side, device='cpu'):
        
        alpha = .5
        h1 = self.gnn_layer(x,edge_index)
        h1 = h1 * alpha + x * (1-alpha)

        h2 = self.gat(h1,edge_index)
        h2 = h2 * alpha + h1 * (1-alpha)

        h_global_pool = torch.tensor([])
        for batch_id in batch.unique():
            players_game = h2[batch == batch_id]
            team_0 = team_side[batch == batch_id]==0
            team_1 = team_side[batch == batch_id]==1
            # g_0 = global_mean_pool(players_game[team_0],torch.zeros(len(players_game[team_0]),dtype=int))
            # g_1 = global_mean_pool(players_game[team_1],torch.zeros(len(players_game[team_1]),dtype=int))
            g_0 = self.readout_vector(players_game[team_0],torch.zeros(len(players_game[team_0]),dtype=int),self.gfunc,device=device).squeeze()
            g_1 = self.readout_vector(players_game[team_1],torch.zeros(len(players_game[team_1]),dtype=int),self.gfunc,device=device).squeeze()
            h_global_pool = torch.cat([h_global_pool,torch.cat([g_0,g_1]).reshape(1,-1)])


        # 2. Readout layer
        # h_global_mean_pool = self.readout_vector(h2,batch,self.gfunc,device=device).squeeze()
        # y_pred = self.prelu(self.lin_global(h_global_mean_pool))

        y_pred = self.prelu(self.lin_global(h_global_pool))
        y_pred = F.dropout(y_pred, p=self.dropout, training=self.training)

        return y_pred

In [660]:
season_list = []
for year in range(3):
    season_list.append('202'+str(year)+'-2'+str(year+1))

graph_dataset = MyOwnDataset(
    root='.',
    dataset_name='nba_graph_dataset_v2',
    season_list=season_list,
    df_nba_teams=df_nba_teams,
    season_campaign_all_teams_dict=season_campaign_all_teams_dict
)

graph_dataset.edge_index = add_self_loops(graph_dataset.edge_index)[0]


train_size = int(graph_dataset.len()*.6)
val_size = int(graph_dataset.len()*.2)
test_size = int(graph_dataset.len()) - train_size - val_size

data_train_loader = DataLoader(
    dataset=graph_dataset[:train_size],
    batch_size=30
)
data_val_loader = DataLoader(
    dataset=graph_dataset[train_size:train_size+val_size],
    batch_size=30
)
data_test_loader = DataLoader(
    dataset=graph_dataset[train_size+val_size:],
    batch_size=30
)

In [661]:
# data_train_loader = DataLoader(G_completed_team_1_2_list,30)
# in_channels = G_completed_team_1_2_list[0].x[0].shape[0]
in_channels = 37
model = GNN_model(
    in_channels=in_channels,
    out_channels=1,
    gfunc=[global_mean_pool, global_max_pool, global_add_pool],
    dropout=.7,
    num_classes=graph_dataset.num_classes,
)

print(model)

epoch_number = 100
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
lr = 5e-4
weight_decay = 5e-2

path_writer = lambda scenario: "runs/scenario_{}".format(scenario)

parameters_random_search = {
    'lr':lambda :np.random.choice([5e-2,5e-3,5e-4,5e-5]),
    'weight_decay':lambda :np.random.choice([5e-2,5e-3,5e-4,5e-5]),
}

def create_random_combinations(parameters_random_search, total_number_scenarios):
    hyperparameter_items = [{scenario:parameters_random_search[scenario]() for scenario in parameters_random_search} for number_scenarios in range(total_number_scenarios)]
    return hyperparameter_items

scenarios_search = create_random_combinations(parameters_random_search,20)

list_main_kpis_epoch_scenario = []
for scenario in range(len(scenarios_search)):  
    print("Scenario {}, Hyperparameters: {}".format(scenario,scenarios_search[scenario]))
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=scenarios_search[scenario]['lr'],
        weight_decay=scenarios_search[scenario]['weight_decay']
    )
    # optimizer = torch.optim.Adam(
    #     model.parameters(),
    #     lr=lr,
    #     weight_decay=weight_decay
    # )

    writer = SummaryWriter(path_writer(scenario))

    list_main_kpis_epoch = []
    for epoch in range(1, epoch_number+1):    
        print(f'Epoch: {epoch}')
        train(
            model=model,
            data_train_loader=data_train_loader,
            criterion=criterion,
            optimizer=optimizer,
            )
        acc_test, losses_test, precision_test, precision_avg_test, recall_test, recall_avg_test, f1_score_result_test, f1_score_result_avg_test, c_matrix_test, actual_y_total_cpu, pred_y_total_cpu = test(
            model=model,
            data_train_loader=data_val_loader,
            criterion=criterion,
        )


        dict_main_kpis_epoch = {
            'acc_test':acc_test,
            'losses_test':losses_test,
            'precision_test':precision_test,
            'precision_avg_test':precision_avg_test,
            'recall_test':recall_test,
            'recall_avg_test':recall_avg_test,
            'f1_score_result_test':f1_score_result_test,
            'f1_score_result_avg_test':f1_score_result_avg_test,
            'c_matrix_test':c_matrix_test,
            'actual_y_total_cpu': actual_y_total_cpu,
            'pred_y_total_cpu': pred_y_total_cpu
        }

        print(dict_main_kpis_epoch)
        writer.add_scalar('acc_test',acc_test,epoch)
        writer.add_scalar('losses_test',losses_test,epoch)
        writer.add_scalar('precision_avg_test',precision_avg_test,epoch)
        writer.add_scalar('recall_avg_test',recall_avg_test,epoch)
        writer.add_scalar('f1_score_result_avg_test',f1_score_result_avg_test,epoch)        


        list_main_kpis_epoch.append(dict_main_kpis_epoch)

    writer.add_hparams(
        {
            'lr':scenarios_search[scenario]['lr'],
            'weight_decay':scenarios_search[scenario]['weight_decay']
        },
        {
            'acc_test':acc_test,
            'losses_test':losses_test,
            'precision_avg_test':precision_avg_test,
            'recall_avg_test':recall_avg_test,
            'f1_score_result_avg_test':f1_score_result_avg_test,
        }
    )

    list_main_kpis_epoch_scenario.append(list_main_kpis_epoch)

GNN_model(
  (lin_global): Linear(in_features=222, out_features=2, bias=True)
  (gnn_layer): Sequential(
    (0): GINConv(nn=Sequential(
    (0): Linear(in_features=37, out_features=37, bias=True)
    (1): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=37)
    (3): Linear(in_features=37, out_features=37, bias=True)
    (4): PReLU(num_parameters=37)
  ))
    (1): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Linear(in_features=37, out_features=37, bias=True)
    (3): PReLU(num_parameters=37)
  )
  (prelu): PReLU(num_parameters=1)
  (gat): GATConv(37, 37, heads=1)
)
Scenario 0, Hyperparameters: {'lr': 0.0005, 'weight_decay': 0.05}
Epoch: 1
{'acc_test': 0.4766949152542373, 'losses_test': 4.028679231802623, 'precision_test': array([0.47174254, 0.52112676]), 'precision_avg_test': 0.49779481039711915, 'recall_test': array([0.89835575, 0.09906292]), 'recall_avg_test': 0.4766949152542373, 