In [None]:
from torch_geometric.datasets import Planetoid
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np
import copy
import time 
import matplotlib.pyplot as plt
import pickle # Lokales Speichern von Objekten
import keyboard

from GNM_Toolbox.tools.tools import *
from GNM_Toolbox.gnm import *
from GNM_Toolbox.data.dataloader import *

# Load the 'Cora' dataset. load_dataset is not defined in this notebook;
# presumably it comes from one of the GNM_Toolbox star imports above (verify).
dataset = load_dataset('Cora')

# Helper Functions

In [None]:
def find_each_nearest(target_list, out_list):
    """Return, for every target, the index of the nearest entry in ``out_list``.

    Given targets ``(a_0, a_1, a_2, ...)`` and entries
    ``((b_0, x), (b_1, x), (b_2, x), ...)``, the result ``l`` is a list of
    indices such that, for every ``i < len(target_list)``,
    ``abs(target_list[i] - out_list[l[i]][0])`` is minimal.

    Both sequences are expected to be sorted in ascending order
    (``out_list`` by the first component of its entries). The search is a
    single forward sweep, so the returned indices never decrease. When two
    neighbouring entries are equally close to a target, the later one wins.

    Args:
        target_list: Sorted sequence of target values.
        out_list: Sorted sequence of tuples whose first component is compared
            against the targets.

    Returns:
        A list with one index into ``out_list`` per target. Empty if
        ``target_list`` is empty.

    Raises:
        ValueError: If ``out_list`` is empty while ``target_list`` is not.
    """
    n_out = len(out_list)
    if len(target_list) == 0:
        return []
    if n_out == 0:
        raise ValueError("out_list must not be empty when targets are given")

    result = []
    j = 0
    for target in target_list:
        # Advance while the next entry is at least as close as the current one.
        # The bound check keeps a single-entry out_list from indexing past the end.
        while (j + 1 < n_out
               and abs(target - out_list[j][0]) >= abs(target - out_list[j + 1][0])):
            j += 1
        result.append(j)
    return result
            
def get_best_values_indices(targets, lambdas):
    """Match every target to the nearest entry of ``lambdas``.

    ``lambdas`` is sorted IN PLACE by the first component of its entries, and
    the returned indices refer to that sorted order, so callers must index
    the same (now sorted) list.

    Args:
        targets: Sequence of target values, passed on to find_each_nearest.
        lambdas: Mutable list of tuples; sorted in place by ``entry[0]``.

    Returns:
        The list of indices produced by find_each_nearest.
    """
    def leading_value(entry):
        return entry[0]

    lambdas.sort(key=leading_value)
    nearest_indices = find_each_nearest(targets, lambdas)
    return nearest_indices

def h(x):
    """Compute ``exp(s / 13 - 4) - (s - 15) / 15`` with ``s`` the sum of ``x`` over dim 1.

    Args:
        x: Tensor with at least two dimensions; it is summed along ``dim=1``.

    Returns:
        Tensor holding one value per slice along ``dim=1``.
    """
    scale, shift, center, width = 13, 4, 15, 15
    row_totals = torch.sum(x, dim=1)
    exponential_part = torch.exp(row_totals / scale - shift)
    linear_part = (row_totals - center) / width
    return exponential_part - linear_part
    
def pi_test(X, y):
    """Return ``sigmoid(-log(35) + h(X) + 1.6 * y)`` element-wise.

    Args:
        X: Tensor passed to ``h`` (summed there along ``dim=1``).
        y: Tensor (or scalar) that is scaled by 1.6 and added to the logits.

    Returns:
        Tensor of values in (0, 1).
    """
    intercept = -torch.log(torch.tensor(35.))
    feature_weight = 1
    label_weight = 1.6
    logits = intercept + feature_weight * h(X) + label_weight * y
    return torch.sigmoid(logits)

def pi_complicated(X, y):
    """Return ``1 / (1 + 35 * exp(h - 1.6 * y))`` element-wise.

    Here ``h = exp(s / 13 - 4) - (s - 15) / 15`` and ``s`` is the sum of ``X``
    over dim 1. The whole computation stays in torch (no NumPy round trip),
    so it also works for CUDA tensors and for tensors that require grad.

    NOTE(review): algebraically this equals ``sigmoid(-log(35) - h + 1.6 * y)``,
    whereas ``pi_test`` in this notebook uses ``+ h``. Confirm which sign is
    intended.

    Args:
        X: Tensor with at least two dimensions; summed along ``dim=1``.
        y: Tensor broadcastable against the row sums of ``X``.

    Returns:
        Tensor of values in (0, 1).
    """
    row_totals = torch.sum(X, dim=1)
    # Named h_values so the module-level function h is not shadowed.
    h_values = torch.exp(row_totals / 13 - 4) - (row_totals - 15) / 15
    return 1 / (1 + 35 * torch.exp(h_values - 1.6 * y))

def pi_simple(x, y):
    """Return ``sigmoid(0 + 1 * y)``; the first argument is ignored.

    Args:
        x: Unused; present so the signature matches the other ``pi_*`` functions.
        y: Tensor fed to the sigmoid.

    Returns:
        Tensor of values in (0, 1) with the shape of ``y``.
    """
    intercept, slope = 0, 1
    logits = intercept + slope * y
    return torch.sigmoid(logits)

def create_mask_from_pi(data, pi):
    """Draw a random boolean mask with per-entry probabilities ``pi(data.x, data.y)``.

    Each entry of the mask is an independent Bernoulli draw taken with
    ``np.random.binomial`` (so it follows NumPy's global random state).

    Args:
        data: Object exposing ``x`` and ``y`` attributes, both passed to ``pi``.
        pi: Callable ``pi(x, y)`` returning one probability per entry, as a
            torch tensor or a NumPy array.

    Returns:
        Boolean CPU tensor with ``p.shape[0]`` entries.
    """
    p = pi(data.x, data.y)
    if isinstance(p, torch.Tensor):
        # NumPy can only read CPU tensors that are detached from autograd;
        # without this, CUDA or grad-requiring probabilities raise an error.
        p = p.detach().cpu().numpy()
    draws = np.random.binomial(size=p.shape[0], n=1, p=p)
    return torch.tensor(draws == 1, dtype=torch.bool)

def split_known_mask_into_val_and_train_mask(known, ratio=0.8):
    """Randomly split the True entries of ``known`` into a train and a validation mask.

    Every True entry is assigned independently: with probability ``ratio`` it
    goes to the train mask, otherwise to the validation mask. One
    ``np.random.binomial`` draw is made per True entry, in index order.

    Args:
        known: 1-D boolean tensor marking the available entries.
        ratio: Probability that an available entry lands in the train mask.

    Returns:
        Tuple ``(val_mask, train_mask)`` of boolean tensors shaped like
        ``known``. Note the order: validation first.
    """
    train_mask = torch.zeros_like(known, dtype=torch.bool)
    val_mask = torch.zeros_like(known, dtype=torch.bool)
    for position, is_known in enumerate(known):
        if not (is_known == True):
            continue
        goes_to_train = np.random.binomial(1, ratio) == 1
        chosen_mask = train_mask if goes_to_train else val_mask
        chosen_mask[position] = True
    return val_mask, train_mask

def calculate_lambda(train_mask, y):
    """Return the ratio (#labels equal to 1) / (#labels equal to 0) inside ``train_mask``.

    Labels other than 0 and 1 are ignored. The counts are taken with
    vectorised comparisons instead of a Python loop over single elements.

    Args:
        train_mask: Boolean mask (or index) selecting the training entries of ``y``.
        y: Label tensor.

    Returns:
        The ratio as a Python float.

    Raises:
        ZeroDivisionError: If no selected label equals 0.
    """
    train_labels = y[train_mask]
    n_class_0 = int((train_labels == 0).sum())
    n_class_1 = int((train_labels == 1).sum())
    if n_class_0 == 0:
        # Same exception type as the bare division, but with a usable message.
        raise ZeroDivisionError("train_mask selects no entry with label 0")
    return n_class_1 / n_class_0
        
def insert_into_list(l, item, t):
    """Insert ``item`` into ``l`` in place, ordered by distance of ``entry[0]`` to ``t``.

    ``l`` is expected to be ordered by ascending ``abs(t - entry[0])``. The
    item is placed at the front if it is at least as close to ``t`` as the
    first entry, otherwise between the first pair of neighbours whose
    distances bracket its own, and at the end if no such pair exists.

    Args:
        l: List of tuples to insert into (modified in place).
        item: Tuple whose first component is compared against ``t``.
        t: Target value.

    Returns:
        None.
    """
    def distance_to_target(entry):
        return abs(t - entry[0])

    if len(l) == 0:
        l.insert(0, item)
        return

    item_distance = distance_to_target(item)
    if item_distance <= distance_to_target(l[0]):
        l.insert(0, item)
        return

    for position in range(1, len(l)):
        closer_neighbour = distance_to_target(l[position - 1])
        farther_neighbour = distance_to_target(l[position])
        if closer_neighbour <= item_distance and item_distance <= farther_neighbour:
            l.insert(position, item)
            return

    l.append(item)

# Data Setup

In [None]:
# Set up data: use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
data = dataset[0].to(device)
data.num_classes = 2
# Binarise the labels: original class 3 becomes class 1, every other class becomes class 0.
# NOTE(review): the comment that used to accompany this step described the opposite
# mapping (class 3 -> 0, all others -> 1); confirm which one is intended.
y = (data.y == 3).to(data.y.dtype)
data.y = y

## Analysis of whether the loss function $L_1$ is actually better than NLL

In [None]:
# Evaluate the actual pi by applying pi_test to the features and labels in data
pi_true = pi_test(data.x, data.y)

# Load masks
# NOTE(review): pickle_read is not defined in this notebook (presumably it comes from
# the GNM_Toolbox star imports); only the first of its three return values is used.
all_masks,_,_ = pickle_read('m_masks.pkl')
subset = [1.2, 1.7, 2.2]
# Keep only the entries of all_masks stored under the keys listed in subset
choosen_masks = {k: all_masks[k] for k in subset}

# Three masks: 1.2, 1.7, 2.2
# Eight noise levels: 0.0025, 0.005, 0.0075, 0.01, 0.0125, 0.015, 0.0175, 0.02

In [None]:
# Evaluate the 'perfect' loss, the 'perfect' loss with noise, and the default training loss
# (per the section heading the default is presumably NLL; verify against train_one_net).
IT_per_mask = 4
# NOTE(review): the progress counter below assumes every key holds as many masks as key 1.7.
NB_masks = len(choosen_masks[1.7])
M = len(choosen_masks)
t_0 = time.time()
noise_levels = [0, 0.0025, 0.005, 0.0075, 0.01, 0.0125, 0.015, 0.0175, 0.02]
all_models = dict()

# Iterate over the mask groups
for i, l in enumerate(choosen_masks):
    sms_models = list() # SM Standard
    # One result list per noise level; derived from noise_levels so both stay in sync
    smn_models = [list() for _ in noise_levels] # SM Advanced with noises
    
    for j, mask_tupel in enumerate(choosen_masks[l]):
        _, train_mask, val_mask = mask_tupel

        # Train IT_per_mask models for each mask
        for k in range(IT_per_mask):
            print_status(i * NB_masks * IT_per_mask + j * IT_per_mask + k, M * NB_masks * IT_per_mask, t_0)
            # The same noise sample is reused for all noise levels and only scaled differently
            noise = torch.randn_like(pi_true)
            for m, noise_level in enumerate(noise_levels):
                pi = pi_true + noise_level * noise
                # Store everything train_one_net returns except its first element, plus the mask index j
                smn_models[m].append((*train_one_net(data, 
                                                  train_mask, 
                                                  val_mask,
                                                  loss_function=weighted_categorial_crossentropy_loss(pi[train_mask], reduction='mean'),
                                                  val_loss_function=weighted_categorial_crossentropy_loss(pi[val_mask], reduction='mean'))[1:], j))
            # Baseline run with train_one_net's default loss functions
            sms_models.append((*train_one_net(data, train_mask, val_mask)[1:], j))
    all_models[l] = (sms_models, *smn_models)
    # Optional per-group checkpoint (disabled):
    #pickle_write('l1_analysis-noise-{}.pkl'.format(i), all_models)
    
pickle_write('l1_analysis-noise-final-2.pkl', all_models)