In [6]:
import os
import sys
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import umap

from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

sys.path.append('..')
from semilearn.core.utils import get_net_builder, get_dataset, over_write_args_from_file
from semilearn.algorithms.openmatch.openmatch import OpenMatchNet
from semilearn.algorithms.iomatch.iomatch import IOMatchNet

In [7]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--c', type=str, default='')

def load_model_at(step='best'):
    args.step = step
    if step == 'best':
        args.load_path = '/'.join(args.load_path.split('/')[1:-1]) + "/model_best.pth"
    if step == 'latest':
        #args.load_path = '/'.join(args.load_path.split('/')[:-1]) + "/model_at_{args.step}_step.pth"
        args.load_path = '/'.join(args.load_path.split('/')[:-1]) + "/latest_model.pth"
    else:
        args.load_path = '/'.join(args.load_path.split('/')[:-1]) + "/"+str(step)+"_model.pth"
        
    print(args.load_path)
    checkpoint_path = os.path.join(args.load_path)
    checkpoint = torch.load(checkpoint_path)
    load_model = checkpoint['ema_model']
    load_state_dict = {}
    for key, item in load_model.items():
        if key.startswith('module'):
            new_key = '.'.join(key.split('.')[1:])
            load_state_dict[new_key] = item
        else:
            load_state_dict[key] = item
    save_dir = '/'.join(checkpoint_path.split('/')[:-1])
    if step == 'best':
        args.save_dir = os.path.join(save_dir, f"model_best")
    else:
        args.save_dir = os.path.join(save_dir, f"step_{args.step}")
    os.makedirs(args.save_dir, exist_ok=True)
    _net_builder = get_net_builder(args.net, args.net_from_name)
    net = _net_builder(num_classes=args.num_classes)
    if args.algorithm == 'openmatch':
        net = OpenMatchNet(net, args.num_classes)
    elif args.algorithm == 'iomatch':
        net = IOMatchNet(net, args.num_classes)
    else:
        raise NotImplementedError
    keys = net.load_state_dict(load_state_dict)
    print(f'Model at step {args.step} loaded!')
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    
    return net

In [8]:
def evaluate_open(net, dataset_dict, num_classes, extended_test=False):
    full_loader = DataLoader(dataset_dict['test']['full'], batch_size=256, drop_last=False, shuffle=False, num_workers=1)
    if extended_test:
        extended_loader = DataLoader(dataset_dict['test']['extended'], batch_size=1024, drop_last=False, shuffle=False, num_workers=4)

    total_num = 0.0
    y_true_list = []
    y_pred_closed_list = []
    y_pred_ova_list = []
    
    results = {}
    
    with torch.no_grad():
        for data in tqdm(full_loader):
            x = data['x_lb']
            y = data['y_lb']

            if isinstance(x, dict):
                x = {k: v.cuda() for k, v in x.items()}
            else:
                x = x.cuda()
            y = y.cuda()

            num_batch = y.shape[0]
            total_num += num_batch
            
            out = net(x)
            logits, logits_open = out['logits'], out['logits_open']    
            pred_closed = logits.data.max(1)[1]
            
            probs = F.softmax(logits, 1)
            probs_open = F.softmax(logits_open.view(logits_open.size(0), 2, -1), 1)
            tmp_range = torch.arange(0, logits_open.size(0)).long().cuda()
            unk_score = probs_open[tmp_range, 0, pred_closed]
            pred_open = pred_closed.clone()
            pred_open[unk_score > 0.5] = num_classes

            y_true_list.extend(y.cpu().tolist())
            y_pred_closed_list.extend(pred_closed.cpu().tolist())
            y_pred_ova_list.extend(pred_open.cpu().tolist())

    y_true = np.array(y_true_list)

    closed_mask = y_true < num_classes
    open_mask = y_true >= num_classes
    y_true[open_mask] = num_classes

    y_pred_closed = np.array(y_pred_closed_list)
    y_pred_ova = np.array(y_pred_ova_list)

    # Closed Accuracy on Closed Test Data
    y_true_closed = y_true[closed_mask]
    y_pred_closed = y_pred_closed[closed_mask]
    closed_acc = accuracy_score(y_true_closed, y_pred_closed)
    closed_cfmat = confusion_matrix(y_true_closed, y_pred_closed, normalize='true')
    results['c_acc_c_p'] = closed_acc
    results['c_cfmat_c_p'] = closed_cfmat  

    # Open Accuracy on Full Test Data
    open_acc = balanced_accuracy_score(y_true, y_pred_ova)
    open_cfmat = confusion_matrix(y_true, y_pred_ova)
    results['o_acc_f_hq'] = open_acc  
    results['o_cfmat_f_hq'] = open_cfmat
        
    if extended_test:
        with torch.no_grad():
            for data in tqdm(extended_loader):
                x = data['x_lb']
                y = data['y_lb']

                if isinstance(x, dict):
                    x = {k: v.cuda() for k, v in x.items()}
                else:
                    x = x.cuda()
                y = y.cuda()

                num_batch = y.shape[0]
                total_num += num_batch

                out = net(x)
                logits, logits_open = out['logits'], out['logits_open']    
                pred_closed = logits.data.max(1)[1]

                probs = F.softmax(logits, 1)
                probs_open = F.softmax(logits_open.view(logits_open.size(0), 2, -1), 1)
                tmp_range = torch.arange(0, logits_open.size(0)).long().cuda()
                unk_score = probs_open[tmp_range, 0, pred_closed]
                pred_open = pred_closed.clone()
                pred_open[unk_score > 0.5] = num_classes

                y_true_list.extend(y.cpu().tolist())
                y_pred_closed_list.extend(pred_closed.cpu().tolist())
                y_pred_ova_list.extend(pred_open.cpu().tolist())
            
        y_true = np.array(y_true_list)

        open_mask = y_true >= num_classes
        y_true[open_mask] = num_classes
        y_pred_ova = np.array(y_pred_ova_list)
        
        # Open Accuracy on Extended Test Data
        open_acc = balanced_accuracy_score(y_true, y_pred_ova)
        open_cfmat = confusion_matrix(y_true, y_pred_ova, normalize='true')
        results['o_acc_e_hq'] = open_acc  
        results['o_cfmat_e_hq'] = open_cfmat
    
    print(f"#############################################################\n"
              f" Closed Accuracy on Closed Test Data: {results['c_acc_c_p'] * 100:.2f}\n"
              f" Open Accuracy on Full Test Data:     {results['o_acc_f_hq'] * 100:.2f}\n"
              #f" Open Accuracy on Extended Test Data: {results['o_acc_e_hq'] * 100:.2f}\n"
              f"#############################################################\n"
        )
    
    return results

In [9]:
def evaluate_io(args, net, dataset_dict, extended_test=False):
    """
    evaluation function for open-set SSL setting
    """

    full_loader = DataLoader(dataset_dict['train_ulb'], batch_size=256, drop_last=False, shuffle=False, num_workers=1)
    if extended_test:
        extended_loader = DataLoader(dataset_dict['test']['extended'], batch_size=1024, drop_last=False, shuffle=False, num_workers=4)

    total_num = 0.0
    y_true_list = []
    p_list = []
    pred_p_list = []
    pred_hat_q_list = []
    pred_q_list = []
    pred_q_prob_list=[]
    pred_hat_p_list = []

    with torch.no_grad():
        for data in tqdm(full_loader):
            x = data['x_ulb_w']
            y = data['y_ulb']
            print('y',len(y),y)
            
            if isinstance(x, dict):
                x = {k: v.cuda() for k, v in x.items()}
            else:
                x = x.cuda()
            y = y.cuda()
            y_true_list.extend(y.cpu().tolist())

            num_batch = y.shape[0]
            total_num += num_batch

            outputs = net(x)
            logits = outputs['logits']
            logits_mb = outputs['logits_mb']
            logits_open = outputs['logits_open']

            # predictions p of closed-set classifier
            p = F.softmax(logits, 1)
            pred_p = p.data.max(1)[1]
            pred_p_list.extend(pred_p.cpu().tolist())

            # predictions hat_q from (closed-set + multi-binary) classifiers
            r = F.softmax(logits_mb.view(logits_mb.size(0), 2, -1), 1)
            tmp_range = torch.arange(0, logits_mb.size(0)).long().cuda()
            hat_q = torch.zeros((num_batch, args.num_classes + 1)).cuda()
            o_neg = r[tmp_range, 0, :]
            o_pos = r[tmp_range, 1, :]
            hat_q[:, :args.num_classes] = p * o_pos
            hat_q[:, args.num_classes] = torch.sum(p * o_neg, 1)
            pred_hat_q = hat_q.data.max(1)[1]
            pred_hat_q_list.extend(pred_hat_q.cpu().tolist())

            # predictions q of open-set classifier
            q = F.softmax(logits_open, 1)
            pred_q = q.data.max(1)[1]
            print('p',p[0],'o_pos',o_pos[0])
            print('q.data',q.data[0:10])
            print('hatq.data',hat_q.data[0:10])
            print('pred_q',pred_q)
            pred_q_prob = q.data
            pred_q_list.extend(pred_q.cpu().tolist())
            pred_q_prob_list.extend(pred_q_prob)

            # prediction hat_p of open-set classifier
            hat_p = q[:, :args.num_classes] / q[:, :args.num_classes].sum(1).unsqueeze(1)
            pred_hat_p = hat_p.data.max(1)[1]
            pred_hat_p_list.extend(pred_hat_p.cpu().tolist())

        y_true = np.array(y_true_list)
        #print('y_true',len(y_true),y_true)
        closed_mask = y_true < args.num_classes
        open_mask = y_true >= args.num_classes
        #y_true[open_mask] = args.num_classes
        #print('pred_q[censys_index]',pred_q[censys_index])

        pred_p = np.array(pred_p_list)
        pred_hat_p = np.array(pred_hat_p_list)
        pred_q = np.array(pred_q_list)
        pred_hat_q = np.array(pred_hat_q_list)
        
        # closed accuracy of p / hat_p on closed test data
        c_acc_c_p = accuracy_score(y_true[closed_mask], pred_p[closed_mask])
        c_acc_c_hp = accuracy_score(y_true[closed_mask], pred_hat_p[closed_mask])
        set_label =  ['adscore',
 'ahrefs',
 'arbor',
 'archive',
 'binaryedge',

 'criminalip',
 'cybergreen',

 'fofa',
 'internet_census',
 'internettl',
 'intrinsec',
 'ipip',
 'leakix',
 'onyphe',
 'quadmetrics',
 'quake',
 'rapid7',
 'rau',
 'shadowserver',
 'shodan',
 'stretchoid',
 'tum',
 'webRay',
 'x_threatbook',
 'zoomeye',
'driftnet(added in unknown)',
 'unknown',
 'censys(added in unknown)'                     ]
 
        c_cfmat_c_p = confusion_matrix(y_true[closed_mask], pred_p[closed_mask])
        
        
        
        print('y_true[closed_mask]',y_true[closed_mask])
        print('pred_hat_p[closed_mask]',pred_hat_p[closed_mask])
        c_cfmat_c_hp = confusion_matrix(y_true[closed_mask], pred_hat_p[closed_mask], normalize='true')
        np.set_printoptions(precision=3, suppress=True)

        # open accuracy of q / hat_q on full test data
        o_acc_f_q = balanced_accuracy_score(y_true, pred_q)
        o_acc_f_q_f = f1_score(y_true, pred_q,average='weighted')
        o_acc_f_hq = balanced_accuracy_score(y_true, pred_hat_q)
        o_acc_f_hq_f = f1_score(y_true, pred_hat_q,average='weighted')
        o_cfmat_f_q = confusion_matrix(y_true, pred_q)
        o_cfmat_f_hq = confusion_matrix(y_true, pred_hat_q)
        
   
        index_list=list(set(y_true)|set(pred_q))
        label_list=[]
        for item in index_list:
            label_list.append(set_label[item])

    
        y_true_arr = y_true.copy()
        
        
        value_to_delete = 26
        
        mask = y_true != value_to_delete
        filtered_y_true = y_true_arr[mask]
        filtered_pred_q = pred_q[mask]
        o_acc_f_q_filtered = balanced_accuracy_score(filtered_y_true, filtered_pred_q)
        y_true_arr[open_mask] = args.num_classes
        filtered_y_true_open_masked = y_true_arr[mask]
        o_acc_f_q_filtered_masked = balanced_accuracy_score(filtered_y_true_open_masked, filtered_pred_q)
        
        print('o_acc_f_q_filtered',o_acc_f_q_filtered) 
        print('o_acc_f_q_filtered_masked',o_acc_f_q_filtered_masked)

        o_acc_e_q = o_acc_e_hq = 0
        o_cfmat_e_q = None
        o_cfmat_e_hq = None

        if extended_test:
            unk_scores = []
            unk_scores_q = []
            for data in tqdm(extended_loader):
                x = data['x_lb']
                y = data['y_lb']

                if isinstance(x, dict):
                    x = {k: v.cuda() for k, v in x.items()}
                else:
                    x = x.cuda()
                y = y.cuda()
                y_true_list.extend(y.cpu().tolist())

                num_batch = y.shape[0]
                total_num += num_batch

                outputs = net(x)
                logits = outputs['logits']
                logits_mb = outputs['logits_mb']
                logits_open = outputs['logits_open']

                # predictions p of closed-set classifier
                p = F.softmax(logits, 1)
                pred_p = p.data.max(1)[1]
                pred_p_list.extend(pred_p.cpu().tolist())

                # predictions hat_q of (closed-set + multi-binary) classifiers
                r = F.softmax(logits_mb.view(logits_mb.size(0), 2, -1), 1)
                tmp_range = torch.arange(0, logits_mb.size(0)).long().cuda()
                hat_q = torch.zeros((num_batch, args.num_classes + 1)).cuda()
                o_neg = r[tmp_range, 0, :]
                o_pos = r[tmp_range, 1, :]
                unk_score = torch.sum(p * o_neg, 1)
                hat_q[:, :args.num_classes] = p * o_pos
                hat_q[:, args.num_classes] = torch.sum(p * o_neg, 1)
                pred_hat_q = hat_q.data.max(1)[1]
                pred_hat_q_list.extend(pred_hat_q.cpu().tolist())

                # predictions q of open-set classifier
                q = F.softmax(logits_open, 1)
                pred_q = q.data.max(1)[1]
                pred_q_list.extend(pred_q.cpu().tolist())

                # prediction hat_p of open-set classifier
                hat_p = q[:, :args.num_classes] / q[:, :args.num_classes].sum(1).unsqueeze(1)
                pred_hat_p = hat_p.data.max(1)[1]
                pred_hat_p_list.extend(pred_hat_p.cpu().tolist())

            y_true = np.array(y_true_list)
            open_mask = y_true >= args.num_classes
            y_true[open_mask] = args.num_classes

            pred_q = np.array(pred_q_list)
            pred_hat_q = np.array(pred_hat_q_list)

            # open accuracy of q / hat_q on extended test data
            o_acc_e_q = balanced_accuracy_score(y_true, pred_q)
            o_acc_e_hq = balanced_accuracy_score(y_true, pred_hat_q)
            o_cfmat_e_q = confusion_matrix(y_true, pred_q, normalize='true')
            o_cfmat_e_hq = confusion_matrix(y_true, pred_hat_q, normalize='true')

        eval_dict = {'c_acc_c_p': c_acc_c_p, 'c_acc_c_hp': c_acc_c_hp,
                     'o_acc_f_q': o_acc_f_q, 'o_acc_f_hq': o_acc_f_hq,
                     'o_acc_e_q': o_acc_e_q, 'o_acc_e_hq': o_acc_e_hq,
                     'c_cfmat_c_p': c_cfmat_c_p, 'c_cfmat_c_hp': c_cfmat_c_hp,
                     'o_cfmat_f_q': o_cfmat_f_q, 'o_cfmat_f_hq': o_cfmat_f_hq,
                     'o_cfmat_e_q': o_cfmat_e_q, 'o_cfmat_e_hq': o_cfmat_e_hq,
                     'pred_q_prob_list':pred_q_prob_list,
                     'o_acc_f_q_f':o_acc_f_q_f,
                     'label_list':label_list,
                     'o_acc_f_q_filtered':o_acc_f_q_filtered,
                     'pred_q':pred_q,
                     'pred_p':pred_p,
                     'y_true':y_true
                    }
        print('fscore',o_acc_f_q_f,'f_hq',o_acc_f_hq_f)
        print(f"#############################################################\n"
              f" Closed Accuracy on Closed Test Data (p / hp): {c_acc_c_p * 100:.2f} / {c_acc_c_hp * 100:.2f}\n"
              f" Open Accuracy on Full Test Data (q / hq):     {o_acc_f_q * 100:.2f} / {o_acc_f_hq * 100:.2f}\n"
              f" Open Accuracy on Extended Test Data (q / hq): {o_acc_e_q * 100:.2f} / {o_acc_e_hq * 100:.2f}\n"
              f"#############################################################\n"
            )

        return eval_dict

In [15]:
import numpy as np
import pandas as pd
import pickle

dataset_name='SelfDeploy24'


for protocol in ['http','tls','dns']:
    if(dataset_name=='SelfDeploy24'):
        dataset='selfdeploy_merge_month_2_4_2weektest'
        name='merge_month_2_4_2week'
        train_ensemble_dir='train_ensemble_batch64'
        earliest_time = pd.to_datetime('2024-01-31 16:00:00')
        latest_time = pd.to_datetime('2024-04-12 16:00:00')
    if(dataset_name=='SelfDeploy25'):
        dataset='selfdeploy_24_25_fulldata_2weektest'
        name='selfdeploy_24_25_2week'
        train_ensemble_dir='train_ensemble_selfdeploy25_batch64'
        earliest_time = pd.to_datetime('2024-12-10 16:00:00') 
        latest_time = pd.to_datetime('2025-02-25 16:00:00')

    config_name='iomatch_exp_'+dataset+'_'+protocol+'.yaml'#iomatch_cifar10_14000_28class_exp_month2_4_http_perservice_40dayscompletetrain_second
    args = parser.parse_args(args=['--c', 'config/openset_cv/iomatch/'+config_name])#test5#iomatch_cifar10_150_eval_25_n#iomatch_cifar10_14000_15class_comp_1
    over_write_args_from_file(args, args.c)#iomatch_cifar10_300_13_2class_comp_1#iomatch_cifar10_14000_13_2class_comp_1
    args.data_dir = 'data'
    total_epoch=250
    batch_size=64
    args.bsz = batch_size
    args.load_path = "./saved_models/openset_cv/iomatch_"+dataset+"_"+protocol+"_extra_ep"+str(total_epoch)+"_bs"+str(batch_size)+"/latest_model.pth"
    
    dataset_dict = get_dataset(args, args.algorithm, args.dataset, args.num_labels, args.num_classes, args.data_dir, eval_open=False)
    best_net = load_model_at('1999')#2199
    eval_dict = evaluate_io(args, best_net, dataset_dict)

    
    if('http' in config_name):
        proto='http'
    if('tls' in config_name):
        proto='tls'
    if('dns' in config_name):
        proto='dns'
    if('ssh' in config_name):
        proto='ssh'
    if('rdp' in config_name):
        proto='rdp'
        
    
    mode='train'
    npy_path1 = '../dataset/npy/ip_array_'+proto+'_'+name+'_'+mode+'.npy'
    #npy_path1 = 'ip_array_'+proto+'_'+name+'_'+mode+'.npy'
    npy_path2 = '../dataset/npy/timestamp_'+proto+'_'+name+'_'+mode+'.npy'
    #dataset = NumpyDataset(npy_path1,npy_path2)
    
    loaded_array = np.load(npy_path1,allow_pickle=True)
    loaded_array
    len(loaded_array)
    loaded_timestamp = np.load(npy_path2,allow_pickle=True)
    
    
    pred = eval_dict['pred_q']
    y_true = eval_dict['y_true']
    df = pd.DataFrame({'prediction': pred, 'ip': loaded_array,'groundtruth':y_true,'timestamp':loaded_timestamp})
    
    print(df)
    
    # A list of honeypot addresses to remove
    ips_to_remove = ['192.168.1.0', '192.168.2.0','192.168.3.0','192.168.4.0','192.168.5.0']
    
    # Removes the row with the specified IP address
    df = df[~df['ip'].isin(ips_to_remove)]
            
    
    df.to_csv('result_'+proto+'_'+name+'_'+mode+'.csv', index=False)
    
    
    import pandas as pd
    
    
    for day_interval in [7]:
    
        df = pd.read_csv('result_'+proto+'_'+name+'_'+mode+'.csv')
        
        
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        
        
        print(f"Earliest time: {earliest_time}")
        print(f"Latest time: {latest_time}")
        
        
        time_range = latest_time - earliest_time
        
        
        num_segments = (time_range.days) // day_interval 
        
        
        segment_dfs = []
        
        
        for i in range(num_segments):
            start_time = earliest_time + pd.Timedelta(days=i * day_interval)
            end_time = earliest_time + pd.Timedelta(days=(i + 1) * day_interval)
            print(start_time,end_time)
            
            segment_df = df[(df['timestamp'] >= start_time) & (df['timestamp'] < end_time)]
            
            
            segment_dfs.append(segment_df)
        
        
        for i, segment_df in enumerate(segment_dfs):
            start_time = earliest_time + pd.Timedelta(days=i * day_interval)
            end_time = earliest_time + pd.Timedelta(days=(i + 1) * day_interval)
            print(f"Segment {i+1} ({start_time} to {end_time}):")
            print(segment_df)
            print("\n")
            
        file_dir=train_ensemble_dir+'/result_'+str(day_interval)+'day_segment_'+proto+'/'
        import os
        
        
        directory = file_dir
        
        
        if not os.path.exists(directory):
            os.mkdir(directory)
            print(f"Directory '{directory}' created.")
        else:
            print(f"Directory '{directory}' already exists.")
            
        
        for i, segment_df in enumerate(segment_dfs):
            segment_df.to_csv(train_ensemble_dir+'/result_'+str(day_interval)+'day_segment_'+proto+f'/segment_{i+1}.csv', index=False)
        
        #The most common of the predicted label is calculated for each ip
        def protocol_level_prediciton(ip_train2,y_train2,train2_label_dic_common):
            a=pd.DataFrame()
            a['ip']=ip_train2
            a['label']=y_train2
            a['pred_label']=None
            a=a.drop_duplicates()
            for index, row in a.iterrows():
                a.loc[index,'pred_label']=train2_label_dic_common[row['ip']]
            return a
        
        import os
        import pandas as pd
        
        
        folder_path = train_ensemble_dir+'/result_'+str(day_interval)+'day_segment_'+proto
        
        
        files = [f for f in os.listdir(folder_path) if f.startswith('seg') and f.endswith('.csv')]
        
        
        dfs = []
        
        for file in files:
            file_path = os.path.join(folder_path, file)
            df = pd.read_csv(file_path)
            dfs.append(df)
        
        dfs_merge_protocol_level={}
        
        for i, df in enumerate(dfs):
            print(f"DataFrame from {files[i]}:")

        
            test_http_label_dic = {}
            for value, group_df in df.groupby('ip'):
                test_http_label_dic[value] = group_df
                
            test_http_label_dic_common={}
            for j in test_http_label_dic.keys():
                test_http_label_dic_common[j]=test_http_label_dic[j]['prediction'].mode().values[0]
            
            df_ip_http_test=protocol_level_prediciton(df['ip'],df['groundtruth'],test_http_label_dic_common)
            df_ip_http_test=df_ip_http_test.rename(columns={'pred_label':'pred_label_'+proto})
            dfs_merge_protocol_level[files[i]]=df_ip_http_test

        with open('proto_level_'+str(day_interval)+'daysegment_pred_'+proto+'_'+name+'_'+mode+'.pkl', 'wb') as f:
            pickle.dump(dfs_merge_protocol_level, f)
         
        dfs_merge_protocol_level

dataset_name merge_month_2_4_2week
data_dir data/cifar10
num_labels 2500 num_classes 25 this mode use all samples
dataset_data (4214, 18, 18, 3)
dataset_data (65480, 18, 18, 3)
len_seen_indices 1029
dataset_data (1029, 18, 18, 3)
dataset_data (13779, 18, 18, 3)
lb_dset <semilearn.datasets.cv_datasets.datasetbase.BasicDataset object at 0x7f9e1e0f6220> lb_dset 4214
./saved_models/openset_cv/iomatch_selfdeploy_merge_month_2_4_2weektest_http_extra_ep250_bs64/1999_model.pth


  checkpoint = torch.load(checkpoint_path)


Model at step 1999 loaded!


  0%|▏                                          | 1/256 [00:00<00:35,  7.10it/s]

y 256 tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
p tensor([0.0051, 0.0048, 0.0043, 0.0065, 0.2541, 0.0046, 0.0066, 0.0098, 0.0096,
        0.0087, 0.0058, 0.0064, 0.0088, 0.0034, 0.0066,

  2%|█                                          | 6/256 [00:00<00:09, 25.89it/s]

p tensor([0.0051, 0.0048, 0.0043, 0.0065, 0.2529, 0.0046, 0.0066, 0.0097, 0.0095,
        0.0086, 0.0058, 0.0063, 0.0088, 0.0034, 0.0066, 0.0071, 0.0051, 0.0054,
        0.4939, 0.1203, 0.0051, 0.0036, 0.0041, 0.0049, 0.0070],
       device='cuda:0') o_pos tensor([0.2231, 0.2632, 0.2950, 0.3123, 0.4149, 0.1509, 0.2904, 0.3440, 0.1731,
        0.3366, 0.3345, 0.2126, 0.3607, 0.1497, 0.2997, 0.3103, 0.2757, 0.2546,
        0.5761, 0.3979, 0.2016, 0.2752, 0.3045, 0.2671, 0.2974],
       device='cuda:0')
q.data tensor([[0.0090, 0.0094, 0.0088, 0.0100, 0.1228, 0.0087, 0.0093, 0.0140, 0.0102,
         0.0139, 0.0101, 0.0149, 0.0095, 0.0100, 0.0099, 0.0105, 0.0092, 0.0080,
         0.2296, 0.0542, 0.0114, 0.0081, 0.0095, 0.0097, 0.0107, 0.3688],
        [0.0090, 0.0094, 0.0088, 0.0100, 0.1228, 0.0087, 0.0093, 0.0140, 0.0102,
         0.0139, 0.0101, 0.0149, 0.0095, 0.0100, 0.0099, 0.0105, 0.0092, 0.0080,
         0.2296, 0.0542, 0.0114, 0.0081, 0.0095, 0.0097, 0.0107, 0.3688],
        [0.0090

  4%|█▊                                        | 11/256 [00:00<00:07, 32.73it/s]

p tensor([0.0051, 0.0048, 0.0043, 0.0065, 0.2529, 0.0046, 0.0066, 0.0097, 0.0095,
        0.0086, 0.0058, 0.0063, 0.0088, 0.0034, 0.0066, 0.0071, 0.0051, 0.0054,
        0.4939, 0.1203, 0.0051, 0.0036, 0.0041, 0.0049, 0.0070],
       device='cuda:0') o_pos tensor([0.2231, 0.2632, 0.2950, 0.3123, 0.4149, 0.1509, 0.2904, 0.3440, 0.1731,
        0.3366, 0.3345, 0.2126, 0.3607, 0.1497, 0.2997, 0.3103, 0.2757, 0.2546,
        0.5761, 0.3979, 0.2016, 0.2752, 0.3045, 0.2671, 0.2974],
       device='cuda:0')
q.data tensor([[0.0090, 0.0094, 0.0088, 0.0100, 0.1228, 0.0087, 0.0093, 0.0140, 0.0102,
         0.0139, 0.0101, 0.0149, 0.0095, 0.0100, 0.0099, 0.0105, 0.0092, 0.0080,
         0.2296, 0.0542, 0.0114, 0.0081, 0.0095, 0.0097, 0.0107, 0.3688],
        [0.0090, 0.0094, 0.0088, 0.0100, 0.1228, 0.0087, 0.0093, 0.0140, 0.0102,
         0.0139, 0.0101, 0.0149, 0.0095, 0.0100, 0.0099, 0.0105, 0.0092, 0.0080,
         0.2296, 0.0542, 0.0114, 0.0081, 0.0095, 0.0097, 0.0107, 0.3688],
        [0.0090

  6%|██▋                                       | 16/256 [00:00<00:06, 36.37it/s]

tensor([[1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-03, 9.9325e-04, 1.2631e-03, 1.3217e-03,
         2.0847e-03, 5.2589e-01],
        [1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-03, 9.9325e-04, 1.2631e-03, 1.3217e-03,
         2.0847e-03, 5.2589e-01],
        [1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-

  8%|███▍                                      | 21/256 [00:00<00:06, 38.31it/s]

tensor([0.0607, 0.0567, 0.0818, 0.0629, 0.9366, 0.0546, 0.0818, 0.0419, 0.0381,
        0.0298, 0.0579, 0.0933, 0.0537, 0.0493, 0.0726, 0.0588, 0.0758, 0.0690,
        0.0601, 0.0401, 0.0755, 0.0852, 0.0701, 0.0760, 0.0656],
       device='cuda:0')
q.data tensor([[1.3361e-04, 1.7261e-04, 1.5549e-04, 3.0863e-04, 8.2351e-01, 4.4711e-04,
         1.6276e-04, 5.0516e-04, 1.1148e-03, 1.3962e-04, 3.6488e-04, 3.3461e-03,
         6.4020e-04, 8.5893e-04, 1.8914e-04, 2.6075e-04, 2.9259e-04, 1.5211e-04,
         5.1891e-03, 2.5060e-03, 1.0989e-03, 2.0505e-04, 1.2473e-04, 1.7968e-04,
         6.9473e-04, 1.5725e-01],
        [9.0198e-03, 9.4023e-03, 8.7932e-03, 1.0013e-02, 1.2291e-01, 8.7559e-03,
         9.2851e-03, 1.3990e-02, 1.0279e-02, 1.3907e-02, 1.0120e-02, 1.4942e-02,
         9.5558e-03, 1.0037e-02, 9.9409e-03, 1.0487e-02, 9.1917e-03, 7.9969e-03,
         2.2893e-01, 5.4303e-02, 1.1457e-02, 8.0892e-03, 9.5230e-03, 9.7618e-03,
         1.0788e-02, 3.6852e-01],
        [5.4790e-03, 5.1092e

 10%|████▎                                     | 26/256 [00:00<00:05, 39.36it/s]

p tensor([1.7114e-04, 6.7037e-05, 1.4528e-04, 2.2718e-04, 4.1260e-03, 7.1040e-04,
        1.1900e-04, 8.6280e-04, 9.7544e-01, 1.5174e-04, 3.1030e-04, 9.2779e-04,
        2.9656e-04, 8.3572e-05, 1.3942e-04, 4.5277e-04, 3.0255e-04, 1.5700e-04,
        9.0450e-03, 3.1386e-03, 2.4351e-03, 1.1226e-04, 2.2172e-04, 1.4359e-04,
        2.0822e-04], device='cuda:0') o_pos tensor([1.1351e-02, 5.1119e-03, 3.6525e-03, 1.8836e-03, 8.8453e-03, 5.7804e-03,
        6.7049e-03, 1.8364e-02, 9.8948e-01, 1.6156e-04, 1.1111e-03, 8.7195e-03,
        1.8351e-03, 1.5321e-04, 5.1685e-03, 3.0327e-03, 3.5122e-03, 7.3526e-03,
        1.5033e-02, 1.0201e-02, 2.0681e-02, 4.2530e-03, 1.2879e-03, 4.5008e-03,
        9.6090e-03], device='cuda:0')
q.data tensor([[3.7454e-06, 9.4122e-06, 8.2215e-06, 1.7228e-05, 5.0525e-04, 1.6566e-04,
         8.2737e-06, 2.1149e-04, 9.2368e-01, 4.9731e-06, 2.4866e-05, 5.4213e-04,
         3.3595e-05, 7.3744e-06, 1.1687e-05, 1.2702e-05, 1.1823e-05, 1.3445e-05,
         1.1801e-03, 6.374

 12%|█████                                     | 31/256 [00:00<00:05, 40.05it/s]

tensor([[9.1714e-06, 8.3296e-06, 4.1940e-06, 1.1663e-05, 8.8423e-04, 4.7068e-05,
         5.8784e-06, 8.2805e-06, 5.5040e-06, 1.9470e-05, 9.0315e-06, 4.5894e-05,
         4.2661e-06, 4.7732e-07, 5.1824e-06, 6.8126e-06, 6.7192e-06, 1.0957e-05,
         9.2427e-01, 2.7140e-04, 5.2043e-05, 1.0695e-05, 4.7566e-06, 5.5636e-06,
         1.8671e-05, 7.4277e-02],
        [5.8523e-06, 7.1921e-06, 3.0676e-06, 7.8944e-06, 4.5683e-04, 9.0648e-05,
         3.5278e-06, 3.0876e-06, 8.2127e-06, 1.3067e-05, 7.7163e-06, 2.3246e-05,
         2.4063e-06, 2.9435e-07, 2.8334e-06, 5.2159e-06, 4.0217e-06, 7.6519e-06,
         9.4532e-01, 1.1853e-04, 4.0210e-05, 8.4134e-06, 3.9175e-06, 4.4157e-06,
         1.2850e-05, 5.3837e-02],
        [1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-

 14%|█████▉                                    | 36/256 [00:00<00:05, 40.66it/s]

tensor([[4.4847e-06, 2.5591e-06, 1.5514e-05, 3.3043e-06, 1.2782e-05, 4.1797e-06,
         8.0994e-06, 4.4341e-05, 1.9803e-05, 1.5956e-05, 6.1996e-07, 4.8212e-06,
         5.1985e-06, 1.7322e-05, 5.9096e-06, 9.4321e-06, 6.6565e-06, 1.0605e-05,
         1.1870e-05, 5.1918e-06, 9.6958e-01, 5.7457e-06, 1.0537e-05, 1.2236e-05,
         9.1576e-06, 3.0173e-02],
        [4.4847e-06, 2.5591e-06, 1.5514e-05, 3.3043e-06, 1.2782e-05, 4.1797e-06,
         8.0994e-06, 4.4341e-05, 1.9803e-05, 1.5956e-05, 6.1996e-07, 4.8212e-06,
         5.1985e-06, 1.7322e-05, 5.9096e-06, 9.4321e-06, 6.6565e-06, 1.0605e-05,
         1.1870e-05, 5.1918e-06, 9.6958e-01, 5.7457e-06, 1.0537e-05, 1.2236e-05,
         9.1576e-06, 3.0173e-02],
        [2.2998e-06, 1.2747e-06, 7.3878e-06, 1.7135e-06, 5.1018e-06, 2.1787e-06,
         4.2547e-06, 3.5815e-05, 8.7825e-06, 1.2544e-05, 3.2664e-07, 5.3931e-06,
         2.9724e-06, 1.0016e-05, 3.0992e-06, 5.0141e-06, 3.8633e-06, 5.7651e-06,
         5.6361e-06, 2.3446e-06, 9.7586e-

 16%|██████▋                                   | 41/256 [00:01<00:05, 41.08it/s]

p tensor([0.0016, 0.0035, 0.0038, 0.0019, 0.0123, 0.0089, 0.0049, 0.0241, 0.0008,
        0.7549, 0.0095, 0.0245, 0.0060, 0.0097, 0.0030, 0.0052, 0.0033, 0.0035,
        0.0552, 0.0201, 0.0259, 0.0033, 0.0039, 0.0051, 0.0050],
       device='cuda:0') o_pos tensor([0.0512, 0.1161, 0.1077, 0.1374, 0.0174, 0.0668, 0.1187, 0.1058, 0.0011,
        0.9043, 0.0943, 0.1585, 0.1074, 0.0876, 0.1088, 0.1014, 0.0868, 0.1066,
        0.1317, 0.0846, 0.1557, 0.1287, 0.1326, 0.1138, 0.0856],
       device='cuda:0')
q.data tensor([[2.3259e-03, 1.6836e-03, 2.1532e-03, 1.2122e-03, 4.7344e-03, 6.3559e-03,
         1.2027e-03, 1.8796e-02, 3.5245e-04, 5.2465e-01, 6.4730e-03, 2.1202e-02,
         5.6205e-03, 4.3822e-03, 1.4500e-03, 7.6163e-03, 8.0750e-04, 1.3192e-03,
         2.9086e-02, 5.9374e-03, 2.2166e-02, 1.1584e-03, 2.3609e-03, 1.7665e-03,
         3.2576e-03, 3.2193e-01],
        [8.9784e-03, 9.3776e-03, 8.7656e-03, 9.9832e-03, 1.2323e-01, 8.7362e-03,
         9.2458e-03, 1.4007e-02, 1.0266e-02, 1.3

 18%|███████▌                                  | 46/256 [00:01<00:05, 41.08it/s]

tensor([[2.9151e-03, 2.8212e-03, 4.7529e-03, 4.5770e-03, 4.1164e-02, 3.4602e-03,
         3.0206e-03, 7.3897e-03, 2.0143e-02, 3.5102e-03, 4.3231e-03, 2.5970e-03,
         1.7363e-02, 1.0458e-02, 4.3600e-03, 9.5917e-03, 4.0150e-03, 3.2103e-03,
         1.8447e-02, 1.9097e-02, 9.7594e-03, 2.8741e-03, 4.5486e-03, 3.8231e-03,
         2.9260e-03, 7.8885e-01],
        [2.0794e-05, 1.3190e-05, 5.4245e-05, 1.8296e-05, 2.2066e-05, 2.0084e-05,
         3.0206e-05, 8.7335e-05, 2.7009e-05, 1.2758e-04, 3.9249e-06, 4.4603e-05,
         2.4016e-05, 5.3062e-05, 2.3716e-05, 3.6508e-05, 3.0959e-05, 4.0463e-05,
         1.5416e-04, 2.3655e-05, 9.3602e-01, 2.6588e-05, 4.9394e-05, 4.9424e-05,
         4.6480e-05, 6.2952e-02],
        [7.9091e-04, 7.1460e-04, 3.9771e-04, 7.5068e-04, 2.0905e-02, 6.2880e-04,
         6.4312e-04, 3.7943e-03, 5.5258e-04, 4.8093e-03, 6.8343e-04, 1.5211e-01,
         2.6364e-03, 2.3145e-04, 5.1198e-04, 5.5397e-04, 7.4156e-04, 8.4312e-04,
         2.6940e-01, 5.5914e-03, 3.1397e-

 20%|████████▎                                 | 51/256 [00:01<00:04, 41.16it/s]

tensor([[2.0794e-05, 1.3190e-05, 5.4245e-05, 1.8296e-05, 2.2066e-05, 2.0084e-05,
         3.0206e-05, 8.7335e-05, 2.7009e-05, 1.2758e-04, 3.9249e-06, 4.4603e-05,
         2.4016e-05, 5.3062e-05, 2.3716e-05, 3.6508e-05, 3.0959e-05, 4.0463e-05,
         1.5416e-04, 2.3655e-05, 9.3602e-01, 2.6588e-05, 4.9394e-05, 4.9424e-05,
         4.6480e-05, 6.2952e-02],
        [3.8019e-04, 4.7509e-04, 8.0746e-04, 1.0523e-03, 1.3146e-03, 1.5409e-04,
         9.3109e-04, 3.5601e-01, 3.4375e-04, 3.0811e-03, 1.5054e-04, 1.3893e-03,
         2.4973e-03, 1.7468e-03, 6.5089e-04, 1.0049e-03, 9.9550e-04, 7.5879e-04,
         1.6416e-03, 1.8993e-03, 3.9840e-02, 6.8849e-04, 5.0045e-04, 8.5716e-04,
         3.2558e-04, 5.8051e-01],
        [3.8019e-04, 4.7509e-04, 8.0746e-04, 1.0523e-03, 1.3146e-03, 1.5409e-04,
         9.3109e-04, 3.5601e-01, 3.4375e-04, 3.0811e-03, 1.5054e-04, 1.3893e-03,
         2.4973e-03, 1.7468e-03, 6.5089e-04, 1.0049e-03, 9.9550e-04, 7.5879e-04,
         1.6416e-03, 1.8993e-03, 3.9840e-

 22%|█████████▏                                | 56/256 [00:01<00:04, 41.34it/s]

p tensor([1.1814e-03, 1.4612e-03, 1.9849e-03, 1.4791e-03, 1.6868e-03, 1.3003e-03,
        1.2793e-03, 3.4260e-03, 2.8517e-03, 4.0201e-03, 1.2557e-03, 2.9791e-03,
        3.1346e-03, 3.2884e-03, 8.9587e-04, 1.4512e-03, 1.1340e-03, 1.6600e-03,
        5.8489e-03, 2.5179e-03, 9.4785e-01, 1.4157e-03, 1.8009e-03, 1.8292e-03,
        2.2647e-03], device='cuda:0') o_pos tensor([0.0176, 0.0090, 0.0273, 0.0124, 0.0131, 0.0154, 0.0236, 0.0255, 0.0095,
        0.0317, 0.0031, 0.0150, 0.0077, 0.0161, 0.0265, 0.0252, 0.0273, 0.0244,
        0.0264, 0.0094, 0.9875, 0.0188, 0.0274, 0.0270, 0.0205],
       device='cuda:0')
q.data tensor([[1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [1.3108e-04,

 24%|██████████                                | 61/256 [00:01<00:04, 41.62it/s]

tensor([0.0132, 0.0141, 0.0126, 0.0130, 0.0756, 0.0302, 0.0111, 0.0322, 0.0437,
        0.0370, 0.0177, 0.1944, 0.0282, 0.0195, 0.0091, 0.0174, 0.0112, 0.0154,
        0.2025, 0.0444, 0.1005, 0.0098, 0.0138, 0.0137, 0.0196],
       device='cuda:0') o_pos tensor([0.1154, 0.0823, 0.0711, 0.0941, 0.0800, 0.0729, 0.0893, 0.1495, 0.0595,
        0.1470, 0.0914, 0.5913, 0.1463, 0.0448, 0.0783, 0.0781, 0.1365, 0.1030,
        0.3558, 0.0620, 0.1360, 0.0950, 0.0706, 0.1026, 0.1155],
       device='cuda:0')
q.data tensor([[0.0053, 0.0048, 0.0034, 0.0043, 0.0287, 0.0083, 0.0042, 0.0190, 0.0127,
         0.0195, 0.0071, 0.1602, 0.0097, 0.0049, 0.0036, 0.0047, 0.0030, 0.0028,
         0.1147, 0.0092, 0.0303, 0.0029, 0.0056, 0.0039, 0.0088, 0.5184],
        [0.0046, 0.0067, 0.0035, 0.0068, 0.0373, 0.0082, 0.0044, 0.0544, 0.0374,
         0.0125, 0.0064, 0.0229, 0.0148, 0.0099, 0.0056, 0.0070, 0.0055, 0.0038,
         0.0464, 0.0187, 0.0983, 0.0039, 0.0061, 0.0057, 0.0120, 0.5570],
        [0.0046, 

 26%|██████████▊                               | 66/256 [00:01<00:04, 41.47it/s]

tensor([[1.0330e-03, 7.7418e-04, 2.0114e-03, 1.0603e-03, 1.3361e-02, 3.6223e-03,
         1.2296e-03, 6.8367e-04, 3.4176e-02, 8.8059e-04, 3.8301e-02, 2.7240e-04,
         2.7677e-03, 4.9965e-04, 1.1350e-03, 4.8137e-03, 8.6707e-04, 1.3309e-03,
         2.0300e-02, 1.6225e-02, 1.5336e-02, 1.5609e-03, 9.5197e-04, 1.0942e-03,
         1.8252e-03, 8.3389e-01],
        [2.0794e-05, 1.3190e-05, 5.4245e-05, 1.8296e-05, 2.2066e-05, 2.0084e-05,
         3.0206e-05, 8.7335e-05, 2.7009e-05, 1.2758e-04, 3.9249e-06, 4.4603e-05,
         2.4016e-05, 5.3062e-05, 2.3716e-05, 3.6508e-05, 3.0959e-05, 4.0463e-05,
         1.5416e-04, 2.3655e-05, 9.3602e-01, 2.6588e-05, 4.9394e-05, 4.9424e-05,
         4.6480e-05, 6.2952e-02],
        [1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-

 28%|███████████▋                              | 71/256 [00:01<00:04, 41.44it/s]

tensor([[1.1372e-03, 1.2624e-03, 1.2556e-03, 2.0149e-03, 1.0494e-01, 6.9522e-04,
         1.9110e-03, 3.3327e-03, 1.6408e-03, 2.9063e-03, 1.9479e-03, 1.3415e-03,
         3.1589e-03, 5.1096e-04, 1.9765e-03, 2.1955e-03, 1.4002e-03, 1.3700e-03,
         2.8454e-01, 4.7885e-02, 1.0244e-03, 9.9325e-04, 1.2631e-03, 1.3217e-03,
         2.0847e-03, 5.2589e-01],
        [4.8165e-04, 3.2606e-04, 3.0167e-04, 3.6319e-04, 6.7900e-03, 1.4634e-04,
         4.0035e-04, 6.6613e-02, 7.9582e-04, 2.0205e-03, 2.1111e-04, 2.5436e-01,
         2.7472e-03, 3.4415e-04, 2.8952e-04, 3.2900e-04, 6.3166e-04, 5.2990e-04,
         2.4939e-03, 6.6529e-04, 4.7760e-03, 2.9223e-04, 2.0742e-04, 4.8745e-04,
         8.4334e-04, 6.5256e-01],
        [4.8165e-04, 3.2606e-04, 3.0167e-04, 3.6319e-04, 6.7900e-03, 1.4634e-04,
         4.0035e-04, 6.6613e-02, 7.9582e-04, 2.0205e-03, 2.1111e-04, 2.5436e-01,
         2.7472e-03, 3.4415e-04, 2.8952e-04, 3.2900e-04, 6.3166e-04, 5.2990e-04,
         2.4939e-03, 6.6529e-04, 4.7760e-

 32%|█████████████▎                            | 81/256 [00:02<00:07, 24.73it/s]

tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 

 36%|██████████████▉                           | 91/256 [00:02<00:05, 30.72it/s]

p tensor([0.0122, 0.0101, 0.0136, 0.0150, 0.1264, 0.0398, 0.0118, 0.0219, 0.2087,
        0.0266, 0.0299, 0.0280, 0.0233, 0.0169, 0.0132, 0.0237, 0.0134, 0.0136,
        0.1726, 0.0726, 0.0536, 0.0107, 0.0134, 0.0125, 0.0163],
       device='cuda:0') o_pos tensor([0.1146, 0.1068, 0.1122, 0.0954, 0.1366, 0.1246, 0.1293, 0.1121, 0.4060,
        0.0704, 0.1128, 0.0955, 0.0981, 0.0336, 0.1122, 0.1127, 0.1004, 0.1189,
        0.2535, 0.1515, 0.2027, 0.1111, 0.0916, 0.1177, 0.1185],
       device='cuda:0')
q.data tensor([[4.2560e-03, 5.2174e-03, 4.9181e-03, 5.6069e-03, 4.8821e-02, 1.7868e-02,
         4.9797e-03, 1.1325e-02, 1.0022e-01, 8.4522e-03, 1.1570e-02, 1.8762e-02,
         9.9349e-03, 6.2390e-03, 6.0905e-03, 7.0729e-03, 5.2751e-03, 6.2304e-03,
         9.6876e-02, 3.7689e-02, 3.1498e-02, 4.1943e-03, 6.3812e-03, 4.7245e-03,
         1.0816e-02, 5.2498e-01],
        [1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3

 37%|███████████████▌                          | 95/256 [00:02<00:04, 32.59it/s]

tensor([1.1814e-03, 1.4612e-03, 1.9849e-03, 1.4791e-03, 1.6868e-03, 1.3003e-03,
        1.2793e-03, 3.4260e-03, 2.8517e-03, 4.0201e-03, 1.2557e-03, 2.9791e-03,
        3.1346e-03, 3.2884e-03, 8.9587e-04, 1.4512e-03, 1.1340e-03, 1.6600e-03,
        5.8489e-03, 2.5179e-03, 9.4785e-01, 1.4157e-03, 1.8009e-03, 1.8292e-03,
        2.2647e-03], device='cuda:0') o_pos tensor([0.0176, 0.0090, 0.0273, 0.0124, 0.0131, 0.0154, 0.0236, 0.0255, 0.0095,
        0.0317, 0.0031, 0.0150, 0.0077, 0.0161, 0.0265, 0.0252, 0.0273, 0.0244,
        0.0264, 0.0094, 0.9875, 0.0188, 0.0274, 0.0270, 0.0205],
       device='cuda:0')
q.data tensor([[1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [2.1485e-03, 3

 41%|████████████████▊                        | 105/256 [00:03<00:04, 36.76it/s]

tensor([[8.9784e-03, 9.3776e-03, 8.7656e-03, 9.9832e-03, 1.2323e-01, 8.7362e-03,
         9.2458e-03, 1.4007e-02, 1.0266e-02, 1.3850e-02, 1.0064e-02, 1.4877e-02,
         9.5221e-03, 1.0028e-02, 9.9052e-03, 1.0463e-02, 9.1616e-03, 7.9685e-03,
         2.2897e-01, 5.4377e-02, 1.1411e-02, 8.0608e-03, 9.4942e-03, 9.7344e-03,
         1.0776e-02, 3.6875e-01],
        [1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [2.2300e-03, 3.4555e-03, 1.6518e-03, 2.9984e-03, 2.7129e-02, 7.5053e-03,
         2.1177e-03, 1.8383e-02, 2.4799e-02, 6.4734e-03, 3.6287e-03, 1.3453e-02,
         1.0275e-02, 8.9902e-03, 2.6349e-03, 4.5904e-03, 2.6134e-03, 2.2222e-03,
         4.2751e-02, 1.1767e-02, 1.5161e-

 45%|██████████████████▍                      | 115/256 [00:03<00:03, 39.14it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 47%|███████████████████▏                     | 120/256 [00:03<00:03, 39.67it/s]

tensor([[0.0090, 0.0094, 0.0088, 0.0100, 0.1228, 0.0087, 0.0093, 0.0140, 0.0102,
         0.0139, 0.0101, 0.0149, 0.0095, 0.0100, 0.0099, 0.0105, 0.0092, 0.0080,
         0.2296, 0.0542, 0.0114, 0.0081, 0.0095, 0.0097, 0.0107, 0.3688],
        [0.0040, 0.0033, 0.0029, 0.0035, 0.0384, 0.0056, 0.0028, 0.0182, 0.0046,
         0.0382, 0.0066, 0.2504, 0.0091, 0.0047, 0.0025, 0.0045, 0.0021, 0.0018,
         0.0528, 0.0068, 0.0184, 0.0026, 0.0042, 0.0032, 0.0070, 0.5018],
        [0.0040, 0.0033, 0.0029, 0.0035, 0.0384, 0.0056, 0.0028, 0.0182, 0.0046,
         0.0382, 0.0066, 0.2504, 0.0091, 0.0047, 0.0025, 0.0045, 0.0021, 0.0018,
         0.0528, 0.0068, 0.0184, 0.0026, 0.0042, 0.0032, 0.0070, 0.5018],
        [0.0040, 0.0033, 0.0029, 0.0035, 0.0384, 0.0056, 0.0028, 0.0182, 0.0046,
         0.0382, 0.0066, 0.2504, 0.0091, 0.0047, 0.0025, 0.0045, 0.0021, 0.0018,
         0.0528, 0.0068, 0.0184, 0.0026, 0.0042, 0.0032, 0.0070, 0.5018],
        [0.0040, 0.0033, 0.0029, 0.0035, 0.0384, 0.0056,

 51%|████████████████████▊                    | 130/256 [00:03<00:03, 40.77it/s]

tensor([25, 25,  4, 25, 20, 25, 20, 20, 20, 20, 25, 25, 18, 25, 25, 25, 25, 25,
        25, 25,  4, 25, 20, 20, 25, 25, 25, 25,  4, 20, 20, 25, 25, 25, 25, 20,
        25, 25, 25, 25, 25, 25, 25, 20, 25, 25, 25,  4, 25, 20, 25, 25, 25, 20,
        20, 25, 25, 25, 25, 25, 25, 20, 25, 25, 25, 25, 18, 25, 25, 25, 25, 25,
        25, 25, 25, 25,  4, 25, 25, 25, 25, 20,  7, 25, 25, 20, 25, 20, 25, 25,
        20, 20, 25,  7, 25, 20, 20, 20, 20, 25, 20, 20,  7, 25, 25, 25, 20, 25,
        20, 25, 20, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25,
        25, 25, 20, 20, 25, 20, 25, 25, 25, 25, 20, 25, 20, 25, 25, 20, 25, 20,
        25, 25, 20, 25, 25, 25, 25, 25, 25, 20, 20, 20, 25, 20, 25, 25, 25, 20,
        25, 25, 25, 20, 25, 25, 25, 25, 25, 20, 25, 20, 25, 20, 20, 25, 25, 25,
        25, 20, 25, 25, 25, 25, 25, 25, 25, 20, 20, 20, 25, 20, 20, 20, 25, 25,
        25, 25,  4, 20, 25, 25, 20, 20, 20, 25, 25, 25, 25, 20, 20, 25, 25, 20,
        25, 25, 25, 25, 20, 20, 25, 18, 

 55%|██████████████████████▍                  | 140/256 [00:03<00:02, 41.09it/s]

tensor([[9.0198e-03, 9.4023e-03, 8.7932e-03, 1.0013e-02, 1.2291e-01, 8.7559e-03,
         9.2851e-03, 1.3990e-02, 1.0279e-02, 1.3907e-02, 1.0120e-02, 1.4942e-02,
         9.5558e-03, 1.0037e-02, 9.9409e-03, 1.0487e-02, 9.1917e-03, 7.9969e-03,
         2.2893e-01, 5.4303e-02, 1.1457e-02, 8.0892e-03, 9.5230e-03, 9.7618e-03,
         1.0788e-02, 3.6852e-01],
        [1.3361e-04, 1.7261e-04, 1.5549e-04, 3.0863e-04, 8.2351e-01, 4.4711e-04,
         1.6276e-04, 5.0516e-04, 1.1148e-03, 1.3962e-04, 3.6488e-04, 3.3461e-03,
         6.4020e-04, 8.5893e-04, 1.8914e-04, 2.6075e-04, 2.9259e-04, 1.5211e-04,
         5.1891e-03, 2.5060e-03, 1.0989e-03, 2.0505e-04, 1.2473e-04, 1.7968e-04,
         6.9473e-04, 1.5725e-01],
        [9.0015e-03, 9.3737e-03, 8.7650e-03, 9.9789e-03, 1.2279e-01, 8.7266e-03,
         9.2616e-03, 1.3950e-02, 1.0207e-02, 1.3883e-02, 1.0099e-02, 1.4890e-02,
         9.5295e-03, 1.0007e-02, 9.9143e-03, 1.0457e-02, 9.1630e-03, 7.9719e-03,
         2.2957e-01, 5.4188e-02, 1.1444e-

 59%|████████████████████████                 | 150/256 [00:04<00:02, 41.35it/s]

tensor([20, 20, 20, 25, 20, 25, 25, 20, 25, 25, 20, 25, 25, 25,  4, 20, 20, 25,
        25, 25, 25, 25, 25, 20, 25,  7, 25, 25, 20, 20, 20, 20, 25, 20, 18, 25,
        20, 25, 25,  4, 20, 25, 25, 20, 25, 25,  7, 25, 20, 20, 25, 25, 25, 20,
        18, 18, 25, 25, 25, 25,  4, 25, 25, 25, 20, 25, 25, 20, 25, 20, 25, 20,
        20, 20, 25, 25, 20, 20, 25, 20, 20, 25, 25, 25, 25, 25,  4,  4, 25, 20,
        20, 20, 25,  7, 25, 25, 25, 20, 20, 25, 25, 25, 20, 25, 25, 25, 25, 25,
        25, 25, 25,  7, 25, 25, 20, 25, 25, 20, 25, 25, 25, 20, 25, 18, 25, 25,
        25, 25,  7, 25, 25, 25, 25, 25, 25, 25, 20,  7, 25, 18, 25, 20, 20,  5,
        25, 25, 25, 25,  5, 20, 25, 25, 20, 25, 25, 25, 25,  4, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 11, 25, 25, 25, 25, 11, 25, 25, 11, 25, 25,  4, 25,
        25,  7, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25, 25, 25, 25,
        25, 25, 25, 20,  7, 25, 25, 25, 

 61%|████████████████████████▊                | 155/256 [00:04<00:02, 41.43it/s]

tensor([[3.2356e-03, 2.9574e-03, 2.5401e-03, 2.7006e-03, 1.5361e-01, 8.1872e-03,
         2.9192e-03, 4.3219e-03, 8.4405e-03, 5.2532e-03, 5.2984e-03, 1.1624e-02,
         4.9914e-03, 5.0468e-03, 3.1039e-03, 4.0102e-03, 2.7597e-03, 3.1274e-03,
         2.1657e-01, 2.0472e-02, 2.0754e-02, 2.0352e-03, 2.8464e-03, 2.3344e-03,
         6.9651e-03, 4.9390e-01],
        [8.9784e-03, 9.3776e-03, 8.7656e-03, 9.9832e-03, 1.2323e-01, 8.7362e-03,
         9.2458e-03, 1.4007e-02, 1.0266e-02, 1.3850e-02, 1.0064e-02, 1.4877e-02,
         9.5221e-03, 1.0028e-02, 9.9052e-03, 1.0463e-02, 9.1616e-03, 7.9685e-03,
         2.2897e-01, 5.4377e-02, 1.1411e-02, 8.0608e-03, 9.4942e-03, 9.7344e-03,
         1.0776e-02, 3.6875e-01],
        [4.6237e-03, 6.7480e-03, 3.5417e-03, 6.7603e-03, 3.7327e-02, 8.1522e-03,
         4.4489e-03, 5.4423e-02, 3.7414e-02, 1.2527e-02, 6.3615e-03, 2.2934e-02,
         1.4831e-02, 9.9059e-03, 5.6369e-03, 7.0376e-03, 5.5178e-03, 3.7559e-03,
         4.6414e-02, 1.8660e-02, 9.8283e-

 64%|██████████████████████████▍              | 165/256 [00:04<00:02, 41.61it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 68%|████████████████████████████             | 175/256 [00:04<00:01, 41.56it/s]

tensor([[4.2560e-03, 5.2174e-03, 4.9181e-03, 5.6069e-03, 4.8821e-02, 1.7868e-02,
         4.9797e-03, 1.1325e-02, 1.0022e-01, 8.4522e-03, 1.1570e-02, 1.8762e-02,
         9.9349e-03, 6.2390e-03, 6.0905e-03, 7.0729e-03, 5.2751e-03, 6.2304e-03,
         9.6876e-02, 3.7689e-02, 3.1498e-02, 4.1943e-03, 6.3812e-03, 4.7245e-03,
         1.0816e-02, 5.2498e-01],
        [2.6191e-03, 2.4458e-03, 2.7099e-03, 3.0577e-03, 4.5264e-01, 4.0598e-03,
         3.1085e-03, 4.3258e-03, 3.1610e-03, 5.2742e-03, 1.0042e-02, 8.6757e-03,
         6.7951e-03, 6.7442e-03, 3.2542e-03, 4.3334e-03, 3.2439e-03, 2.5903e-03,
         6.1021e-02, 2.4431e-02, 6.3343e-03, 2.5037e-03, 2.0398e-03, 2.6049e-03,
         5.4337e-03, 3.6655e-01],
        [2.1421e-03, 3.9011e-03, 1.6189e-03, 3.1832e-03, 1.4735e-02, 3.9593e-03,
         2.0471e-03, 2.7704e-01, 8.3808e-03, 2.2162e-02, 2.3358e-03, 1.7069e-02,
         1.0852e-02, 8.7586e-03, 2.7265e-03, 4.6839e-03, 2.1085e-03, 1.2150e-03,
         2.3396e-02, 9.7059e-03, 5.8301e-

 70%|████████████████████████████▊            | 180/256 [00:04<00:01, 41.65it/s]

tensor([25, 25, 25,  7, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 20, 25,  7, 25, 25, 25, 20, 25, 25, 25, 25, 20, 20, 25, 25,
        25, 25, 25, 25, 25, 20, 25, 25, 25, 25, 25,  7, 25,  7, 25,  7,  7,  7,
        25, 20, 25, 20, 25, 25, 25,  4, 25, 25, 25, 20, 20, 20, 25, 25, 20, 25,
        25, 25, 25, 20, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 20, 20, 25, 25,
        25, 25, 25, 20, 25, 25, 25, 25,  7,  4, 25, 25, 18, 20, 20, 25, 25, 25,
        25, 20, 20, 25, 25, 25,  4, 25, 20, 25, 25, 20, 25, 20, 25, 20, 20, 25,
        25, 25, 20, 20, 25, 25, 25, 25, 25, 25, 25, 20, 20, 20, 25, 25, 25, 25,
        25, 20, 20, 25,  4, 25, 20, 25, 25, 25, 20, 20, 20, 20, 25, 20, 25, 25,
        20, 25,  5,  5, 25, 25, 20, 25, 25, 25, 20, 20, 25, 20, 20, 20, 25,  4,
        25, 25, 20, 20, 25, 25, 20, 25, 20, 25, 25, 20,  7,  7,  7, 20,  7,  7,
        25,  7,  7, 25, 20, 25, 20, 20, 25, 25, 25, 20, 20, 25, 20, 20, 25, 20,
        25, 25, 25, 25, 25, 20, 20, 25, 

 74%|██████████████████████████████▍          | 190/256 [00:05<00:01, 41.94it/s]

tensor([[5.8709e-03, 5.0209e-03, 4.5192e-03, 5.9271e-03, 2.3393e-01, 5.4542e-03,
         5.1907e-03, 1.0626e-02, 4.8474e-03, 1.2587e-02, 7.3321e-03, 3.9231e-02,
         8.2976e-03, 7.8363e-03, 5.0475e-03, 6.0393e-03, 4.7551e-03, 3.6827e-03,
         1.3572e-01, 2.2151e-02, 1.0541e-02, 4.7285e-03, 5.0174e-03, 5.2111e-03,
         8.6038e-03, 4.3184e-01],
        [1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-

 78%|████████████████████████████████         | 200/256 [00:05<00:01, 41.82it/s]

p tensor([0.0048, 0.0083, 0.0071, 0.0059, 0.4844, 0.0078, 0.0084, 0.0673, 0.0096,
        0.0438, 0.0113, 0.0802, 0.0178, 0.0134, 0.0072, 0.0126, 0.0048, 0.0064,
        0.0778, 0.0504, 0.0347, 0.0066, 0.0074, 0.0091, 0.0128],
       device='cuda:0') o_pos tensor([0.0847, 0.0817, 0.1011, 0.1006, 0.4700, 0.0454, 0.1025, 0.1547, 0.0192,
        0.1771, 0.0942, 0.2394, 0.1085, 0.0616, 0.0927, 0.0770, 0.1109, 0.0964,
        0.0858, 0.0502, 0.1025, 0.1079, 0.0835, 0.0970, 0.0828],
       device='cuda:0')
q.data tensor([[2.0264e-03, 1.9880e-03, 1.8099e-03, 2.7057e-03, 2.7416e-01, 3.5326e-03,
         1.7236e-03, 1.8031e-02, 3.9704e-03, 1.1917e-02, 4.3409e-03, 4.6489e-02,
         6.5787e-03, 5.2338e-03, 1.9148e-03, 3.3710e-03, 2.0626e-03, 1.2836e-03,
         2.8756e-02, 1.0138e-02, 1.1323e-02, 1.9024e-03, 1.9392e-03, 2.1483e-03,
         5.1368e-03, 5.4551e-01],
        [4.0361e-03, 3.2889e-03, 2.8791e-03, 3.5410e-03, 3.9421e-02, 5.6974e-03,
         2.8395e-03, 1.8179e-02, 4.6234e-03, 3.7

 82%|█████████████████████████████████▋       | 210/256 [00:05<00:01, 42.03it/s]

tensor([[7.7765e-04, 1.0114e-03, 1.9207e-03, 1.4120e-03, 2.5602e-01, 1.0196e-03,
         1.2484e-03, 1.7376e-03, 3.6253e-04, 3.3274e-03, 2.1662e-02, 1.3658e-03,
         4.1262e-03, 1.9064e-03, 1.3869e-03, 4.9373e-03, 8.7418e-04, 1.0275e-03,
         3.2719e-02, 8.7150e-03, 2.0461e-03, 1.6281e-03, 1.0913e-03, 1.5025e-03,
         1.3693e-03, 6.4480e-01],
        [1.2119e-03, 1.2926e-03, 3.1115e-03, 2.0325e-03, 5.5199e-02, 2.1677e-03,
         1.6928e-03, 1.0314e-03, 6.7423e-04, 5.5897e-03, 9.8827e-02, 3.9124e-03,
         1.1026e-02, 1.8002e-03, 1.9357e-03, 8.6380e-03, 1.2306e-03, 1.8319e-03,
         4.0475e-02, 1.2075e-02, 3.2242e-03, 2.2279e-03, 1.2671e-03, 2.3617e-03,
         2.5034e-03, 7.3266e-01],
        [8.7792e-04, 1.0287e-03, 1.6480e-03, 1.1128e-03, 2.9865e-01, 6.5667e-04,
         1.1468e-03, 1.1038e-03, 5.3828e-04, 2.5240e-03, 1.5588e-02, 6.0116e-03,
         4.9380e-03, 1.2859e-03, 1.3151e-03, 4.3932e-03, 7.8798e-04, 1.0685e-03,
         2.4310e-02, 7.1835e-03, 2.0128e-

 84%|██████████████████████████████████▍      | 215/256 [00:05<00:01, 39.50it/s]

tensor([0.0176, 0.0090, 0.0273, 0.0124, 0.0131, 0.0154, 0.0236, 0.0255, 0.0095,
        0.0317, 0.0031, 0.0150, 0.0077, 0.0161, 0.0265, 0.0252, 0.0273, 0.0244,
        0.0264, 0.0094, 0.9875, 0.0188, 0.0274, 0.0270, 0.0205],
       device='cuda:0')
q.data tensor([[1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [4.6116e-03, 8.8819e-03, 6.6797e-03, 9.1315e-03, 4.3790e-02, 1.9864e-02,
         6.6530e-03, 4.0546e-02, 2.9973e-02, 1.9781e-02, 1.2360e-02, 1.4900e-02,
         2.4561e-02, 3.3691e-02, 8.2492e-03, 1.5953e-02, 7.4745e-03, 7.4941e-03,
         3.5967e-02, 5.4236e-02, 3.3481e-02, 6.2949e-03, 7.3388e-03, 8.8497e-03,
         1.7982e-02, 5.2126e-01],
        [1.3108e-04, 1.7887e

 87%|███████████████████████████████████▋     | 223/256 [00:05<00:00, 36.56it/s]

tensor([0.0073, 0.0099, 0.0102, 0.0122, 0.0531, 0.0115, 0.0138, 0.4096, 0.0152,
        0.0357, 0.0126, 0.0310, 0.0295, 0.0373, 0.0095, 0.0203, 0.0154, 0.0099,
        0.0330, 0.0378, 0.1363, 0.0121, 0.0120, 0.0146, 0.0103],
       device='cuda:0') o_pos tensor([0.0586, 0.0543, 0.0894, 0.0948, 0.0363, 0.0154, 0.0747, 0.7900, 0.0240,
        0.0873, 0.0138, 0.0474, 0.0937, 0.0555, 0.0797, 0.0564, 0.0700, 0.0829,
        0.0556, 0.0530, 0.3146, 0.0645, 0.0480, 0.0665, 0.0363],
       device='cuda:0')
q.data tensor([[1.3324e-03, 2.7800e-03, 1.0553e-03, 2.1716e-03, 1.1186e-02, 2.7450e-03,
         1.3235e-03, 3.9025e-01, 5.8970e-03, 1.8340e-02, 1.3166e-03, 1.0247e-02,
         7.8470e-03, 7.1415e-03, 1.8699e-03, 3.4379e-03, 1.3838e-03, 7.4519e-04,
         1.7260e-02, 8.0043e-03, 4.4313e-02, 9.2305e-04, 1.8829e-03, 2.4275e-03,
         3.9262e-03, 4.5019e-01],
        [1.3324e-03, 2.7800e-03, 1.0553e-03, 2.1716e-03, 1.1186e-02, 2.7450e-03,
         1.3235e-03, 3.9025e-01, 5.8970e-03, 1.834

 89%|████████████████████████████████████▎    | 227/256 [00:06<00:00, 35.70it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 90%|████████████████████████████████████▉    | 231/256 [00:06<00:01, 21.07it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 93%|██████████████████████████████████████▎  | 239/256 [00:06<00:00, 25.91it/s]

tensor([20, 25, 20, 25, 20, 25, 25, 25, 25, 25, 20, 25, 25, 20, 25, 25, 25, 25,
        25, 20, 25, 25, 25, 20, 25, 25, 20, 20, 25, 25, 25, 25, 25, 25, 25, 20,
        25, 25, 25, 20, 25, 25, 25, 25, 25, 20, 25, 25, 25, 25, 20, 25, 20, 25,
        25, 25, 20, 20, 25, 25, 25, 25, 25, 25, 25,  4, 25, 25, 25, 20, 25,  4,
        20, 20, 20, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 20, 25, 25, 20,
        25, 25, 25, 20, 20, 20, 25, 25, 20, 20, 20, 25, 25, 25, 20, 25, 18, 18,
        20, 20, 25, 20, 25, 25, 25, 20, 18, 20, 20, 20, 20, 25, 25, 25, 25, 25,
        25, 25, 20, 20, 25, 20, 25, 20, 20, 25, 25, 20, 25, 20, 20, 25, 25, 25,
        25, 25, 20, 20, 25, 25, 20, 25, 20, 25, 25, 25, 20, 20, 25, 20, 25, 25,
        25, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 20, 25, 20, 25, 20, 20, 25,
        25, 18, 20, 20, 20, 20, 25, 20, 25, 25, 20, 20, 25,  4, 25, 25, 20, 25,
        20, 20, 20, 20, 25, 20, 25, 25, 20, 18, 20, 25, 25, 20, 25, 25, 25, 25,
        20, 18, 20, 20, 25, 20, 20, 25, 

 96%|███████████████████████████████████████▌ | 247/256 [00:06<00:00, 29.17it/s]

tensor([[1.4074e-03, 1.0289e-03, 2.5981e-03, 1.9680e-03, 1.0893e-02, 7.0157e-04,
         2.0115e-03, 1.9233e-02, 1.3813e-02, 2.2404e-03, 1.1677e-03, 3.6971e-03,
         9.2828e-03, 1.5090e-03, 1.8984e-03, 2.8841e-03, 2.4094e-03, 2.2353e-03,
         1.4209e-02, 5.4744e-03, 8.3706e-02, 1.2305e-03, 1.5842e-03, 2.2867e-03,
         3.3028e-03, 8.0723e-01],
        [1.7091e-04, 1.5950e-04, 3.9925e-04, 1.4348e-04, 1.6641e-03, 2.7205e-04,
         2.2003e-04, 4.5893e-04, 3.5042e-04, 1.4234e-03, 1.8196e-04, 2.1139e-03,
         5.4699e-04, 6.4146e-04, 2.2147e-04, 4.8594e-04, 2.5987e-04, 3.2667e-04,
         2.0655e-03, 2.3655e-04, 5.2042e-01, 2.1654e-04, 3.3096e-04, 4.0015e-04,
         4.5177e-04, 4.6583e-01],
        [1.0914e-03, 8.6062e-04, 1.3490e-03, 1.2299e-03, 1.4647e-02, 6.7526e-04,
         1.3347e-03, 1.5430e-02, 1.9038e-03, 1.0491e-02, 2.7498e-03, 1.0672e-01,
         7.6713e-03, 1.3377e-03, 1.1904e-03, 1.7882e-03, 1.6186e-03, 1.4875e-03,
         9.2642e-03, 2.1989e-03, 1.3914e-

100%|████████████████████████████████████████▊| 255/256 [00:07<00:00, 31.16it/s]

tensor([[1.1425e-03, 1.2721e-03, 1.2580e-03, 2.0297e-03, 1.0614e-01, 6.9895e-04,
         1.9163e-03, 3.3852e-03, 1.6905e-03, 2.9232e-03, 1.9412e-03, 1.3562e-03,
         3.1816e-03, 5.1576e-04, 1.9783e-03, 2.2028e-03, 1.4092e-03, 1.3757e-03,
         2.8198e-01, 4.7826e-02, 1.0311e-03, 9.9863e-04, 1.2717e-03, 1.3237e-03,
         2.1172e-03, 5.2704e-01],
        [6.9081e-04, 6.5133e-04, 1.6203e-03, 1.2293e-03, 2.9529e-03, 3.7551e-03,
         1.0067e-03, 2.2458e-03, 1.7706e-03, 6.6892e-03, 4.3466e-02, 1.9597e-03,
         5.9523e-03, 1.8980e-03, 8.6310e-04, 4.6095e-03, 9.0561e-04, 1.3929e-03,
         1.3255e-02, 6.2870e-03, 1.6436e-02, 1.3520e-03, 9.7820e-04, 1.5346e-03,
         1.2693e-03, 8.7523e-01],
        [5.5009e-04, 6.0590e-04, 7.8563e-04, 7.3269e-04, 4.4876e-01, 3.1135e-04,
         7.2162e-04, 2.0177e-03, 3.6992e-04, 2.6181e-03, 1.7041e-03, 1.4663e-02,
         2.9302e-03, 1.2083e-03, 8.0204e-04, 1.7439e-03, 6.0363e-04, 5.5261e-04,
         1.7108e-02, 2.3658e-03, 6.9731e-

100%|█████████████████████████████████████████| 256/256 [00:07<00:00, 35.29it/s]


tensor([[1.3108e-04, 1.7887e-04, 3.9990e-05, 1.1504e-04, 8.8646e-04, 3.7413e-04,
         5.6365e-05, 1.5851e-03, 1.3142e-03, 6.3538e-04, 1.5399e-04, 9.5981e-04,
         5.5395e-04, 3.9093e-04, 1.1399e-04, 1.8734e-04, 8.4584e-05, 7.5403e-05,
         1.8739e-03, 1.7195e-04, 8.8241e-01, 4.8661e-05, 1.2382e-04, 7.1357e-05,
         3.9742e-04, 1.0706e-01],
        [8.9784e-03, 9.3776e-03, 8.7656e-03, 9.9832e-03, 1.2323e-01, 8.7362e-03,
         9.2458e-03, 1.4007e-02, 1.0266e-02, 1.3850e-02, 1.0064e-02, 1.4877e-02,
         9.5221e-03, 1.0028e-02, 9.9052e-03, 1.0463e-02, 9.1616e-03, 7.9685e-03,
         2.2897e-01, 5.4377e-02, 1.1411e-02, 8.0608e-03, 9.4942e-03, 9.7344e-03,
         1.0776e-02, 3.6875e-01],
        [5.4790e-03, 5.1092e-03, 4.2070e-03, 5.4097e-03, 3.1947e-02, 6.5193e-03,
         4.6114e-03, 2.7193e-02, 1.3018e-02, 2.1924e-02, 6.0993e-03, 2.1321e-01,
         9.1837e-03, 4.8397e-03, 3.8166e-03, 4.9704e-03, 3.4036e-03, 2.7005e-03,
         1.2075e-01, 1.2742e-02, 1.1501e-

  checkpoint = torch.load(checkpoint_path)


len_seen_indices 1848
dataset_data (1848, 26, 26, 3)
dataset_data (10625, 26, 26, 3)
lb_dset <semilearn.datasets.cv_datasets.datasetbase.BasicDataset object at 0x7f9e1d211460> lb_dset 6981
./saved_models/openset_cv/iomatch_selfdeploy_merge_month_2_4_2weektest_tls_extra_ep250_bs64/1999_model.pth
Model at step 1999 loaded!


  0%|                                                   | 0/180 [00:00<?, ?it/s]

y 256 tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])


  3%|█▏                                         | 5/180 [00:00<00:08, 21.72it/s]

p tensor([3.8062e-04, 3.2625e-04, 2.1666e-04, 4.7234e-04, 9.4003e-01, 5.6738e-03,
        2.7686e-04, 3.7534e-03, 1.5512e-02, 9.2677e-03, 2.4534e-04, 8.2530e-04,
        1.3054e-03, 6.9715e-04, 3.5499e-04, 6.0848e-04, 3.2876e-04, 1.5195e-04,
        1.3710e-04, 7.5778e-04, 9.0510e-03, 3.3913e-04, 3.3247e-04, 2.2823e-04,
        8.7258e-03], device='cuda:0') o_pos tensor([2.8424e-02, 4.7349e-02, 3.7588e-02, 4.2651e-02, 9.6387e-01, 7.5403e-03,
        5.9006e-02, 3.3991e-02, 6.1438e-02, 4.2022e-02, 8.4256e-03, 6.4601e-03,
        4.8085e-02, 5.6070e-03, 4.6627e-02, 2.1268e-02, 2.2975e-02, 3.7025e-02,
        4.7313e-03, 2.3955e-02, 5.9988e-02, 9.5442e-04, 2.8487e-02, 5.5563e-02,
        4.2511e-02], device='cuda:0')
q.data tensor([[1.1570e-05, 3.3014e-05, 3.8393e-05, 3.8007e-05, 8.7389e-01, 2.4154e-04,
         1.5437e-05, 3.9897e-04, 2.0754e-03, 4.3393e-03, 1.6920e-05, 8.5951e-06,
         1.2257e-04, 2.1273e-04, 1.3943e-05, 6.8840e-05, 5.6659e-05, 4.9523e-05,
         8.5427e-06, 1.469

  7%|███                                       | 13/180 [00:00<00:05, 31.98it/s]

tensor([25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 

 12%|████▉                                     | 21/180 [00:00<00:04, 35.56it/s]

tensor([[9.8211e-07, 1.0762e-07, 1.6192e-07, 3.9646e-07, 7.9039e-05, 4.4459e-04,
         3.5117e-07, 1.9809e-07, 9.5427e-01, 7.1157e-06, 1.7465e-07, 2.0882e-06,
         9.9361e-07, 1.7987e-08, 2.5877e-07, 5.6646e-07, 5.2412e-07, 2.8077e-07,
         3.7558e-08, 6.0837e-06, 3.3146e-05, 5.7477e-09, 4.7863e-07, 9.2728e-07,
         7.9664e-07, 4.5147e-02],
        [9.5198e-07, 1.0177e-07, 1.5406e-07, 3.7722e-07, 7.4016e-05, 4.3139e-04,
         3.3125e-07, 1.8755e-07, 9.5532e-01, 6.5965e-06, 1.6497e-07, 2.0677e-06,
         9.4583e-07, 1.6812e-08, 2.4543e-07, 5.3801e-07, 5.0318e-07, 2.6932e-07,
         3.5736e-08, 5.9249e-06, 3.1652e-05, 5.4674e-09, 4.5709e-07, 8.8202e-07,
         7.5491e-07, 4.4125e-02],
        [6.7063e-07, 6.7175e-08, 1.0254e-07, 2.6022e-07, 5.4715e-05, 3.5898e-04,
         2.2604e-07, 1.1604e-07, 9.5929e-01, 4.5827e-06, 1.1003e-07, 1.4729e-06,
         6.3620e-07, 1.0610e-08, 1.6082e-07, 3.7019e-07, 3.4176e-07, 1.8286e-07,
         2.5392e-08, 4.2252e-06, 2.9696e-

 16%|██████▊                                   | 29/180 [00:00<00:04, 36.95it/s]

tensor([[2.5541e-05, 1.6782e-05, 4.8757e-06, 9.2466e-06, 4.2384e-04, 1.3403e-04,
         8.1730e-06, 1.2071e-05, 7.4513e-04, 4.8125e-05, 3.1462e-05, 4.5231e-05,
         2.8514e-05, 1.7511e-05, 1.2477e-05, 1.0309e-05, 8.5016e-06, 1.0110e-05,
         1.2299e-03, 3.6675e-06, 9.0466e-01, 9.1402e-06, 1.7022e-05, 6.3782e-06,
         2.1661e-05, 9.2456e-02],
        [2.7926e-05, 1.8485e-05, 5.4433e-06, 1.0247e-05, 4.4195e-04, 1.4692e-04,
         9.0597e-06, 1.3258e-05, 8.0533e-04, 5.1990e-05, 3.4299e-05, 4.9893e-05,
         3.1232e-05, 1.9177e-05, 1.3782e-05, 1.1382e-05, 9.3855e-06, 1.1181e-05,
         1.3156e-03, 4.0974e-06, 9.0167e-01, 1.0126e-05, 1.8749e-05, 7.0987e-06,
         2.3606e-05, 9.5242e-02],
        [7.2223e-03, 7.1352e-03, 4.9573e-03, 5.9985e-03, 1.4757e-02, 2.6646e-02,
         5.6772e-03, 6.4191e-03, 5.0599e-02, 1.0861e-02, 8.9681e-03, 1.4810e-02,
         8.6548e-03, 7.6369e-03, 6.7470e-03, 5.7176e-03, 5.4859e-03, 6.1188e-03,
         3.9913e-02, 5.0503e-03, 3.7397e-

 21%|████████▋                                 | 37/180 [00:01<00:03, 37.54it/s]

p tensor([0.0016, 0.0022, 0.0021, 0.0012, 0.1037, 0.0019, 0.0020, 0.0038, 0.0047,
        0.0039, 0.0024, 0.0019, 0.0037, 0.0026, 0.0015, 0.0036, 0.0017, 0.0017,
        0.0062, 0.0017, 0.8042, 0.0026, 0.0014, 0.0020, 0.0359],
       device='cuda:0') o_pos tensor([0.1506, 0.1592, 0.1768, 0.1487, 0.1982, 0.1390, 0.1699, 0.1829, 0.1070,
        0.1626, 0.1567, 0.1566, 0.1464, 0.1240, 0.1431, 0.1727, 0.1667, 0.1605,
        0.1372, 0.0948, 0.8336, 0.0785, 0.1719, 0.1620, 0.1755],
       device='cuda:0')
q.data tensor([[4.0175e-03, 3.6979e-03, 2.2923e-03, 2.9065e-03, 2.8441e-02, 7.1692e-03,
         2.7041e-03, 5.6225e-03, 9.9778e-03, 6.6750e-03, 5.5165e-03, 6.0178e-03,
         4.9134e-03, 6.0366e-03, 3.2033e-03, 2.9487e-03, 3.0514e-03, 2.7995e-03,
         1.8488e-02, 2.3430e-03, 5.5850e-01, 2.8923e-03, 2.9939e-03, 2.0959e-03,
         7.1575e-03, 2.9753e-01],
        [4.0735e-03, 3.7491e-03, 2.3229e-03, 2.9494e-03, 2.8385e-02, 7.2460e-03,
         2.7559e-03, 5.7177e-03, 9.9847e-03, 6.7

 25%|██████████▌                               | 45/180 [00:01<00:03, 38.04it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 29%|████████████▎                             | 53/180 [00:01<00:03, 38.25it/s]

tensor([[7.2254e-04, 1.4324e-03, 1.5016e-03, 1.3492e-03, 1.0068e-01, 4.8319e-03,
         1.8946e-03, 3.0555e-03, 4.7146e-03, 2.6938e-01, 2.2784e-03, 6.3296e-04,
         2.4018e-03, 3.4303e-03, 1.8225e-03, 3.0221e-03, 1.1024e-03, 1.2905e-03,
         2.3675e-03, 8.4023e-04, 1.9230e-02, 1.3276e-03, 1.5551e-03, 1.6215e-03,
         3.2995e-03, 5.6422e-01],
        [7.2254e-04, 1.4324e-03, 1.5016e-03, 1.3492e-03, 1.0068e-01, 4.8319e-03,
         1.8946e-03, 3.0555e-03, 4.7146e-03, 2.6938e-01, 2.2784e-03, 6.3296e-04,
         2.4018e-03, 3.4303e-03, 1.8225e-03, 3.0221e-03, 1.1024e-03, 1.2905e-03,
         2.3675e-03, 8.4023e-04, 1.9230e-02, 1.3276e-03, 1.5551e-03, 1.6215e-03,
         3.2995e-03, 5.6422e-01],
        [7.2254e-04, 1.4324e-03, 1.5016e-03, 1.3492e-03, 1.0068e-01, 4.8319e-03,
         1.8946e-03, 3.0555e-03, 4.7146e-03, 2.6938e-01, 2.2784e-03, 6.3296e-04,
         2.4018e-03, 3.4303e-03, 1.8225e-03, 3.0221e-03, 1.1024e-03, 1.2905e-03,
         2.3675e-03, 8.4023e-04, 1.9230e-

 34%|██████████████▏                           | 61/180 [00:01<00:03, 38.28it/s]

tensor([[4.7479e-03, 6.5748e-03, 5.5552e-03, 5.5165e-03, 8.7999e-03, 6.8545e-03,
         4.1519e-03, 3.6255e-01, 6.0781e-03, 1.2071e-02, 5.1275e-03, 1.8373e-02,
         7.2580e-03, 8.4423e-03, 4.7342e-03, 6.8295e-03, 5.4643e-03, 3.4035e-03,
         1.3525e-02, 7.4405e-03, 1.4317e-02, 1.3051e-02, 5.1999e-03, 5.4390e-03,
         7.1931e-02, 3.8657e-01],
        [4.7600e-04, 1.0189e-03, 1.0042e-03, 1.1334e-03, 4.6756e-01, 3.2189e-03,
         5.9156e-04, 1.2501e-02, 1.2000e-02, 1.4094e-02, 5.8619e-04, 1.0760e-03,
         2.3687e-03, 4.7238e-03, 5.3269e-04, 1.4135e-03, 1.3328e-03, 1.0386e-03,
         4.7639e-04, 1.9899e-03, 1.1504e-02, 1.7004e-04, 5.2690e-04, 6.7808e-04,
         1.1534e-02, 4.4645e-01],
        [4.6286e-03, 5.3866e-03, 5.5264e-03, 5.3967e-03, 1.1043e-02, 7.7988e-02,
         4.5664e-03, 4.7195e-03, 2.2207e-01, 1.3675e-02, 6.7222e-03, 1.5638e-02,
         7.0996e-03, 4.8882e-03, 5.1729e-03, 5.5580e-03, 5.0115e-03, 6.5013e-03,
         2.0402e-02, 7.4018e-03, 1.1499e-

 38%|████████████████                          | 69/180 [00:01<00:02, 38.31it/s]

p tensor([0.0081, 0.0100, 0.0074, 0.0083, 0.0305, 0.0295, 0.0086, 0.0434, 0.0614,
        0.0248, 0.0069, 0.1586, 0.0213, 0.0199, 0.0058, 0.0239, 0.0098, 0.0082,
        0.0545, 0.0077, 0.3473, 0.0297, 0.0115, 0.0082, 0.0549],
       device='cuda:0') o_pos tensor([0.2078, 0.1788, 0.1318, 0.1684, 0.0640, 0.2400, 0.1702, 0.2580, 0.1897,
        0.1467, 0.1915, 0.6330, 0.1907, 0.1724, 0.1365, 0.1872, 0.2351, 0.1919,
        0.2077, 0.1118, 0.5228, 0.1922, 0.1761, 0.1817, 0.1618],
       device='cuda:0')
q.data tensor([[5.9746e-03, 6.5125e-03, 4.8945e-03, 5.7979e-03, 5.2243e-03, 1.5074e-02,
         5.6598e-03, 2.3404e-02, 2.3993e-02, 4.5423e-03, 8.0325e-03, 1.2591e-01,
         7.5891e-03, 1.2997e-02, 5.8659e-03, 4.9761e-03, 4.7551e-03, 3.4961e-03,
         2.5260e-02, 4.7375e-03, 1.9982e-01, 1.8904e-02, 5.5928e-03, 4.6947e-03,
         1.7165e-02, 4.4912e-01],
        [4.6945e-03, 6.5144e-03, 5.5007e-03, 5.4532e-03, 8.8060e-03, 6.7443e-03,
         4.1039e-03, 3.6632e-01, 5.9555e-03, 1.2

 43%|█████████████████▉                        | 77/180 [00:02<00:02, 38.21it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 47%|███████████████████▊                      | 85/180 [00:02<00:02, 38.02it/s]

tensor([[2.1672e-04, 1.7199e-04, 1.7326e-04, 2.8822e-04, 3.2376e-01, 3.8040e-03,
         2.6708e-04, 6.6810e-04, 4.8218e-02, 9.1841e-03, 1.0727e-04, 7.2497e-05,
         5.0357e-04, 6.8129e-05, 2.2866e-04, 2.4363e-04, 1.8439e-04, 1.4586e-04,
         2.2326e-04, 3.8880e-04, 1.8025e-02, 2.2433e-05, 2.1914e-04, 2.5834e-04,
         9.8081e-04, 5.9157e-01],
        [1.0844e-03, 1.4609e-03, 1.1694e-03, 1.7180e-03, 2.1686e-03, 9.1468e-04,
         1.4272e-03, 4.2939e-01, 7.1375e-04, 2.5107e-03, 9.6801e-04, 4.5641e-03,
         8.6683e-03, 1.1018e-03, 9.6454e-04, 3.2657e-03, 1.9212e-03, 1.1367e-03,
         1.4908e-03, 1.0498e-03, 3.9155e-03, 4.3819e-03, 1.4595e-03, 1.3092e-03,
         6.0165e-02, 4.6108e-01],
        [1.0814e-03, 4.5085e-04, 7.0077e-04, 7.3767e-04, 1.8496e-03, 6.2834e-02,
         7.7725e-04, 1.1772e-03, 2.4333e-01, 3.3641e-03, 7.0652e-04, 1.5732e-03,
         9.4616e-04, 3.2424e-04, 6.2665e-04, 1.3123e-03, 8.0369e-04, 8.6308e-04,
         9.1546e-03, 9.9527e-04, 1.1147e-

 52%|█████████████████████▋                    | 93/180 [00:02<00:02, 38.22it/s]

tensor([[0.0047, 0.0055, 0.0058, 0.0055, 0.0113, 0.0823, 0.0047, 0.0050, 0.2315,
         0.0141, 0.0068, 0.0159, 0.0072, 0.0050, 0.0053, 0.0057, 0.0052, 0.0067,
         0.0196, 0.0080, 0.1041, 0.0072, 0.0065, 0.0056, 0.0069, 0.4140],
        [0.0050, 0.0069, 0.0058, 0.0058, 0.0094, 0.0072, 0.0044, 0.3533, 0.0063,
         0.0126, 0.0054, 0.0185, 0.0076, 0.0088, 0.0050, 0.0071, 0.0057, 0.0036,
         0.0142, 0.0077, 0.0154, 0.0133, 0.0054, 0.0057, 0.0724, 0.3877],
        [0.0046, 0.0063, 0.0054, 0.0053, 0.0084, 0.0066, 0.0040, 0.3689, 0.0059,
         0.0115, 0.0049, 0.0181, 0.0070, 0.0082, 0.0045, 0.0066, 0.0053, 0.0033,
         0.0131, 0.0072, 0.0137, 0.0127, 0.0050, 0.0052, 0.0716, 0.3869],
        [0.0049, 0.0068, 0.0057, 0.0057, 0.0090, 0.0071, 0.0043, 0.3519, 0.0063,
         0.0122, 0.0053, 0.0190, 0.0075, 0.0088, 0.0049, 0.0070, 0.0056, 0.0035,
         0.0142, 0.0076, 0.0154, 0.0137, 0.0054, 0.0056, 0.0720, 0.3905],
        [0.0051, 0.0059, 0.0062, 0.0060, 0.0120, 0.0810,

 54%|██████████████████████▋                   | 97/180 [00:02<00:02, 37.99it/s]

p tensor([0.0043, 0.0058, 0.0044, 0.0063, 0.0148, 0.0067, 0.0059, 0.5982, 0.0054,
        0.0104, 0.0041, 0.0150, 0.0307, 0.0080, 0.0041, 0.0142, 0.0072, 0.0041,
        0.0079, 0.0067, 0.0143, 0.0168, 0.0064, 0.0066, 0.1914],
       device='cuda:0') o_pos tensor([0.2506, 0.2532, 0.2632, 0.2759, 0.1600, 0.1342, 0.2415, 0.7282, 0.1269,
        0.2504, 0.2283, 0.2826, 0.2845, 0.1379, 0.2368, 0.2275, 0.2633, 0.2750,
        0.1835, 0.1550, 0.2688, 0.2468, 0.2265, 0.1993, 0.3117],
       device='cuda:0')
q.data tensor([[0.0046, 0.0063, 0.0054, 0.0053, 0.0086, 0.0066, 0.0040, 0.3720, 0.0058,
         0.0118, 0.0049, 0.0176, 0.0070, 0.0081, 0.0045, 0.0066, 0.0053, 0.0033,
         0.0130, 0.0072, 0.0135, 0.0126, 0.0050, 0.0053, 0.0718, 0.3840],
        [0.0020, 0.0016, 0.0009, 0.0012, 0.0127, 0.0041, 0.0011, 0.0022, 0.0065,
         0.0050, 0.0024, 0.0022, 0.0022, 0.0020, 0.0014, 0.0013, 0.0012, 0.0012,
         0.0198, 0.0008, 0.6168, 0.0015, 0.0015, 0.0010, 0.0028, 0.3045],
        [0.0017

 56%|███████████████████████                  | 101/180 [00:02<00:02, 38.06it/s]

tensor([[1.2145e-03, 5.0296e-04, 7.7989e-04, 7.9642e-04, 1.9816e-03, 6.7028e-02,
         8.6487e-04, 1.3825e-03, 2.3461e-01, 3.6372e-03, 7.9848e-04, 1.8139e-03,
         1.0328e-03, 3.6111e-04, 7.0048e-04, 1.4330e-03, 9.0045e-04, 9.7558e-04,
         1.0424e-02, 1.0938e-03, 1.0783e-01, 7.6658e-04, 1.5911e-03, 1.0627e-03,
         1.0003e-03, 5.5543e-01],
        [7.2254e-04, 1.4324e-03, 1.5016e-03, 1.3492e-03, 1.0068e-01, 4.8319e-03,
         1.8946e-03, 3.0555e-03, 4.7146e-03, 2.6938e-01, 2.2784e-03, 6.3296e-04,
         2.4018e-03, 3.4303e-03, 1.8225e-03, 3.0221e-03, 1.1024e-03, 1.2905e-03,
         2.3675e-03, 8.4023e-04, 1.9230e-02, 1.3276e-03, 1.5551e-03, 1.6215e-03,
         3.2995e-03, 5.6422e-01],
        [7.2254e-04, 1.4324e-03, 1.5016e-03, 1.3492e-03, 1.0068e-01, 4.8319e-03,
         1.8946e-03, 3.0555e-03, 4.7146e-03, 2.6938e-01, 2.2784e-03, 6.3296e-04,
         2.4018e-03, 3.4303e-03, 1.8225e-03, 3.0221e-03, 1.1024e-03, 1.2905e-03,
         2.3675e-03, 8.4023e-04, 1.9230e-

 58%|███████████████████████▉                 | 105/180 [00:02<00:01, 38.24it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 63%|█████████████████████████▋               | 113/180 [00:03<00:01, 38.07it/s]

tensor([[1.1027e-03, 4.4702e-04, 6.9737e-04, 7.2759e-04, 1.8138e-03, 6.5295e-02,
         7.7749e-04, 1.1611e-03, 2.5606e-01, 3.3011e-03, 7.0355e-04, 1.5382e-03,
         9.2926e-04, 3.0610e-04, 6.2486e-04, 1.3219e-03, 8.0034e-04, 8.7251e-04,
         8.8044e-03, 1.0367e-03, 9.9974e-02, 6.7512e-04, 1.4214e-03, 9.8139e-04,
         8.7365e-04, 5.4775e-01],
        [1.1350e-03, 4.8003e-04, 7.5519e-04, 7.8160e-04, 2.0333e-03, 6.5150e-02,
         8.2501e-04, 1.2957e-03, 2.2365e-01, 3.5764e-03, 7.8519e-04, 1.6595e-03,
         9.8535e-04, 3.5517e-04, 6.7393e-04, 1.3551e-03, 8.5102e-04, 9.3408e-04,
         1.0616e-02, 1.0253e-03, 1.2026e-01, 7.2770e-04, 1.5213e-03, 1.0112e-03,
         9.8293e-04, 5.5657e-01],
        [1.2673e-03, 5.3515e-04, 8.3521e-04, 8.4762e-04, 2.1358e-03, 6.7953e-02,
         9.2003e-04, 1.4764e-03, 2.2462e-01, 3.8320e-03, 8.5578e-04, 1.8434e-03,
         1.0910e-03, 3.9224e-04, 7.4651e-04, 1.5079e-03, 9.4873e-04, 1.0350e-03,
         1.1308e-02, 1.1492e-03, 1.1114e-

 67%|███████████████████████████▌             | 121/180 [00:03<00:01, 37.89it/s]

tensor([[5.7074e-03, 6.2526e-03, 4.6976e-03, 5.5768e-03, 4.9100e-03, 1.4819e-02,
         5.4329e-03, 2.2860e-02, 2.4127e-02, 4.2652e-03, 7.7463e-03, 1.3219e-01,
         7.2920e-03, 1.2671e-02, 5.6208e-03, 4.7573e-03, 4.5411e-03, 3.3214e-03,
         2.4046e-02, 4.5662e-03, 1.9693e-01, 1.8605e-02, 5.3415e-03, 4.4824e-03,
         1.6651e-02, 4.5259e-01],
        [5.9654e-03, 6.5006e-03, 4.9378e-03, 5.8191e-03, 5.0842e-03, 1.5335e-02,
         5.6746e-03, 2.3787e-02, 2.4419e-02, 4.5310e-03, 8.0620e-03, 1.3047e-01,
         7.5777e-03, 1.3008e-02, 5.8814e-03, 5.0052e-03, 4.7807e-03, 3.4947e-03,
         2.5025e-02, 4.8211e-03, 1.9275e-01, 1.9624e-02, 5.6128e-03, 4.7301e-03,
         1.7348e-02, 4.4976e-01],
        [5.5903e-03, 6.1343e-03, 4.6766e-03, 5.5282e-03, 4.6633e-03, 1.5226e-02,
         5.3675e-03, 2.2988e-02, 2.5052e-02, 4.1374e-03, 7.6572e-03, 1.3888e-01,
         7.1754e-03, 1.2456e-02, 5.5487e-03, 4.6967e-03, 4.4919e-03, 3.2675e-03,
         2.3436e-02, 4.6093e-03, 1.8750e-

 72%|█████████████████████████████▍           | 129/180 [00:03<00:01, 38.33it/s]

p tensor([0.0017, 0.0020, 0.0025, 0.0021, 0.0706, 0.0042, 0.0024, 0.0047, 0.0070,
        0.0205, 0.0026, 0.0015, 0.0034, 0.0030, 0.0018, 0.0036, 0.0019, 0.0020,
        0.0772, 0.0014, 0.7648, 0.0026, 0.0018, 0.0023, 0.0125],
       device='cuda:0') o_pos tensor([0.1142, 0.1208, 0.1298, 0.1241, 0.1325, 0.1129, 0.1410, 0.1583, 0.0761,
        0.1477, 0.1278, 0.0981, 0.1287, 0.0962, 0.1163, 0.1414, 0.1259, 0.1246,
        0.1208, 0.0699, 0.8810, 0.0605, 0.1482, 0.1352, 0.1166],
       device='cuda:0')
q.data tensor([[0.0020, 0.0016, 0.0009, 0.0012, 0.0131, 0.0041, 0.0011, 0.0022, 0.0066,
         0.0051, 0.0025, 0.0022, 0.0023, 0.0021, 0.0014, 0.0013, 0.0012, 0.0012,
         0.0197, 0.0008, 0.6164, 0.0015, 0.0016, 0.0010, 0.0028, 0.3040],
        [0.0019, 0.0015, 0.0008, 0.0011, 0.0122, 0.0039, 0.0010, 0.0020, 0.0063,
         0.0047, 0.0023, 0.0021, 0.0021, 0.0019, 0.0013, 0.0012, 0.0011, 0.0011,
         0.0190, 0.0008, 0.6253, 0.0014, 0.0014, 0.0009, 0.0026, 0.3000],
        [0.0067

 76%|███████████████████████████████▏         | 137/180 [00:03<00:01, 35.62it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 81%|█████████████████████████████████        | 145/180 [00:03<00:00, 36.46it/s]

tensor([[4.2910e-06, 2.3442e-06, 3.8942e-06, 4.0340e-06, 8.1686e-05, 8.5353e-05,
         5.2337e-06, 2.4058e-06, 3.9778e-04, 5.0708e-05, 3.6435e-06, 5.7094e-06,
         4.0971e-06, 1.9979e-06, 2.2250e-06, 6.2100e-06, 3.0144e-06, 5.1176e-06,
         6.5789e-04, 7.8697e-07, 9.3683e-01, 7.2794e-07, 5.2495e-06, 3.9295e-06,
         8.7900e-06, 6.1827e-02],
        [1.1072e-03, 1.5112e-03, 1.2214e-03, 1.7774e-03, 2.3523e-03, 9.4687e-04,
         1.4741e-03, 4.2679e-01, 7.3095e-04, 2.6296e-03, 1.0220e-03, 4.4486e-03,
         8.9322e-03, 1.1607e-03, 1.0061e-03, 3.3403e-03, 1.9744e-03, 1.1791e-03,
         1.5095e-03, 1.0802e-03, 4.0905e-03, 4.3091e-03, 1.5097e-03, 1.3695e-03,
         6.0456e-02, 4.6207e-01],
        [1.8874e-04, 2.5881e-04, 2.4105e-04, 2.9454e-04, 4.6206e-01, 6.6613e-04,
         3.1139e-04, 9.3240e-04, 3.5486e-03, 3.8537e-03, 2.9723e-04, 1.1904e-04,
         5.8543e-04, 1.5722e-04, 2.4377e-04, 4.5315e-04, 1.6542e-04, 2.0314e-04,
         7.7179e-05, 3.0385e-04, 4.0458e-

 85%|██████████████████████████████████▊      | 153/180 [00:04<00:00, 37.13it/s]

tensor([0.0049, 0.0022, 0.0038, 0.0040, 0.0134, 0.1645, 0.0035, 0.0072, 0.4480,
        0.0125, 0.0044, 0.0062, 0.0044, 0.0029, 0.0038, 0.0065, 0.0037, 0.0043,
        0.0546, 0.0049, 0.2181, 0.0060, 0.0062, 0.0039, 0.0061],
       device='cuda:0') o_pos tensor([0.2297, 0.2100, 0.1880, 0.1863, 0.1372, 0.4003, 0.2260, 0.1735, 0.5423,
        0.2729, 0.1676, 0.2713, 0.2146, 0.1115, 0.1694, 0.2051, 0.2235, 0.2086,
        0.1825, 0.2071, 0.4897, 0.1169, 0.2373, 0.2518, 0.1498],
       device='cuda:0')
q.data tensor([[0.0046, 0.0054, 0.0055, 0.0054, 0.0111, 0.0783, 0.0046, 0.0047, 0.2225,
         0.0138, 0.0067, 0.0155, 0.0071, 0.0049, 0.0052, 0.0056, 0.0050, 0.0065,
         0.0204, 0.0074, 0.1144, 0.0068, 0.0064, 0.0054, 0.0066, 0.4200],
        [0.0043, 0.0063, 0.0040, 0.0063, 0.0363, 0.0082, 0.0056, 0.0175, 0.0068,
         0.0176, 0.0112, 0.0282, 0.0056, 0.1172, 0.0059, 0.0063, 0.0058, 0.0042,
         0.0099, 0.0085, 0.1528, 0.0153, 0.0036, 0.0036, 0.0180, 0.4910],
        [0.0073, 

 89%|████████████████████████████████████▋    | 161/180 [00:04<00:00, 37.20it/s]

tensor([20, 25, 20, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 13, 25, 25, 20, 20,
        20, 20, 20, 20, 20, 25, 20, 20, 20, 25, 20, 20, 20, 25, 25, 25, 25, 25,
        25, 25, 25, 20, 20, 20, 20, 20, 20, 25, 20, 25, 20, 25, 25, 25, 25, 25,
        25, 25, 25, 20, 20, 20, 20, 20, 20, 20, 20, 20, 25, 25, 20, 20, 20, 20,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25, 25, 13, 25, 25, 25, 25,  4,
        13, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 20, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 20, 25, 25, 25, 25, 25, 25, 20,  4, 25, 25, 25, 20, 25, 25,
        20, 25, 20, 25, 25, 25, 25, 25, 

 92%|█████████████████████████████████████▌   | 165/180 [00:04<00:00, 37.38it/s]

tensor([[4.5873e-03, 6.3691e-03, 5.3969e-03, 5.3500e-03, 8.5179e-03, 6.6466e-03,
         4.0395e-03, 3.6718e-01, 5.8826e-03, 1.1569e-02, 5.0365e-03, 1.8239e-02,
         7.0160e-03, 8.2738e-03, 4.5994e-03, 6.5930e-03, 5.2988e-03, 3.2744e-03,
         1.3132e-02, 7.3078e-03, 1.3854e-02, 1.2895e-02, 4.9961e-03, 5.2431e-03,
         7.1641e-02, 3.8707e-01],
        [4.8885e-03, 6.7548e-03, 5.6760e-03, 5.6462e-03, 9.2408e-03, 7.0161e-03,
         4.2695e-03, 3.5568e-01, 6.2120e-03, 1.2317e-02, 5.2599e-03, 1.8357e-02,
         7.4645e-03, 8.6444e-03, 4.8682e-03, 6.9730e-03, 5.6187e-03, 3.5120e-03,
         1.3993e-02, 7.5535e-03, 1.5173e-02, 1.3170e-02, 5.3347e-03, 5.5664e-03,
         7.2291e-02, 3.8852e-01],
        [4.2862e-04, 9.1924e-04, 9.3234e-04, 9.8448e-04, 5.0780e-01, 4.0321e-03,
         5.0030e-04, 8.9627e-03, 1.8736e-02, 2.2833e-02, 5.7379e-04, 5.0955e-04,
         1.9860e-03, 2.6037e-03, 4.9410e-04, 1.4541e-03, 1.2268e-03, 1.1040e-03,
         4.2949e-04, 2.3746e-03, 9.1306e-

 96%|███████████████████████████████████████▍ | 173/180 [00:04<00:00, 34.75it/s]

tensor([25, 25, 25, 25, 25, 25, 25, 20, 25, 20, 25, 20, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,  4, 25, 20,
        25, 25, 25, 25, 25, 25,  4, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  4, 25, 25,
        25, 25, 25, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  4, 25, 20, 20,  4,
        25, 20, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25, 25, 25, 25, 20, 25, 25, 20, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  4, 20, 25, 20, 20, 20,
        20, 20, 20, 20, 20,  4, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
        25, 25, 25,  4, 25, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,  4, 25, 25,
        25, 25, 25, 25, 25, 25, 25, 25, 

 98%|████████████████████████████████████████▎| 177/180 [00:04<00:00, 33.57it/s]

p tensor([0.0018, 0.0025, 0.0024, 0.0014, 0.1160, 0.0022, 0.0023, 0.0047, 0.0051,
        0.0045, 0.0028, 0.0022, 0.0043, 0.0031, 0.0017, 0.0043, 0.0020, 0.0019,
        0.0064, 0.0020, 0.7783, 0.0031, 0.0016, 0.0023, 0.0412],
       device='cuda:0') o_pos tensor([0.1582, 0.1691, 0.1863, 0.1584, 0.2051, 0.1493, 0.1793, 0.1930, 0.1120,
        0.1741, 0.1674, 0.1678, 0.1555, 0.1345, 0.1512, 0.1823, 0.1762, 0.1698,
        0.1451, 0.1018, 0.8238, 0.0868, 0.1817, 0.1714, 0.1849],
       device='cuda:0')
q.data tensor([[4.6584e-03, 4.3210e-03, 2.7419e-03, 3.4396e-03, 3.0807e-02, 8.1941e-03,
         3.2152e-03, 6.6750e-03, 1.0902e-02, 7.7696e-03, 6.4386e-03, 7.0407e-03,
         5.6542e-03, 7.1871e-03, 3.7900e-03, 3.5085e-03, 3.6220e-03, 3.3039e-03,
         2.0067e-02, 2.8289e-03, 5.3137e-01, 3.5027e-03, 3.5200e-03, 2.4971e-03,
         8.3600e-03, 3.0459e-01],
        [4.3328e-03, 4.0221e-03, 2.5168e-03, 3.1828e-03, 3.0142e-02, 7.6834e-03,
         2.9585e-03, 6.2028e-03, 1.0498e-02, 7.2

100%|█████████████████████████████████████████| 180/180 [00:05<00:00, 35.83it/s]


y_true[closed_mask] [ 4  4  4 ... 24 24 24]
pred_hat_p[closed_mask] [4 4 4 ... 7 7 7]
o_acc_f_q_filtered 0.3886784470078474
o_acc_f_q_filtered_masked 0.4662524130720904
fscore 0.08228865532167008 f_hq 0.08394914487285027
#############################################################
 Closed Accuracy on Closed Test Data (p / hp): 91.29 / 91.29
 Open Accuracy on Full Test Data (q / hq):     36.71 / 36.06
 Open Accuracy on Extended Test Data (q / hq): 0.00 / 0.00
#############################################################

       prediction              ip  groundtruth            timestamp
0               4  170.187.156.85            4  2024-02-01 10:42:08
1               4  167.99.175.210            4  2024-02-01 14:26:28
2               4  167.99.175.210            4  2024-02-01 14:26:50
3               4  104.200.17.225            4  2024-02-01 22:28:05
4               4    45.56.77.151            4  2024-02-01 22:29:18
...           ...             ...          ...                  .

  checkpoint = torch.load(checkpoint_path)


Model at step 1999 loaded!


  2%|▉                                           | 1/49 [00:00<00:06,  7.18it/s]

y 256 tensor([ 4,  4,  4,  4, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 27, 27, 27, 27

 10%|████▍                                       | 5/49 [00:00<00:01, 23.76it/s]

tensor([6.3882e-04, 5.4280e-03, 2.2511e-03, 6.3512e-03, 5.5280e-04, 1.1907e-03,
        3.2069e-03, 4.8545e-03, 5.1222e-03, 5.2129e-03, 7.7216e-03, 2.2046e-04,
        4.1359e-04, 1.5782e-03, 1.8753e-03, 2.1996e-03, 6.2321e-06, 3.7107e-04,
        9.9766e-01, 2.2915e-03, 2.3427e-03, 4.2692e-03, 8.5484e-03, 7.9023e-05,
        1.7194e-03], device='cuda:0')
q.data tensor([[7.9819e-07, 2.3338e-07, 3.7498e-07, 1.6855e-07, 3.3506e-07, 7.2872e-06,
         1.3059e-05, 1.5420e-07, 7.0755e-05, 7.0423e-06, 1.6874e-07, 9.0286e-08,
         3.7186e-08, 5.9497e-06, 4.4772e-07, 2.2028e-07, 9.3963e-09, 1.8079e-07,
         9.8582e-01, 4.0449e-06, 1.3766e-05, 1.7820e-05, 2.4012e-07, 8.9384e-08,
         1.1656e-06, 1.4038e-02],
        [7.9819e-07, 2.3338e-07, 3.7498e-07, 1.6855e-07, 3.3506e-07, 7.2872e-06,
         1.3059e-05, 1.5420e-07, 7.0755e-05, 7.0423e-06, 1.6874e-07, 9.0286e-08,
         3.7186e-08, 5.9497e-06, 4.4772e-07, 2.2028e-07, 9.3963e-09, 1.8079e-07,
         9.8582e-01, 4.0449e-06, 1

 20%|████████▊                                  | 10/49 [00:00<00:01, 31.93it/s]

tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 

 31%|█████████████▏                             | 15/49 [00:00<00:00, 35.73it/s]

tensor([[4.2937e-03, 5.8481e-03, 7.0898e-03, 6.6231e-03, 1.4867e-02, 3.9772e-02,
         1.7057e-02, 6.4407e-03, 1.6203e-01, 5.5906e-03, 4.9608e-03, 5.2931e-03,
         3.7836e-03, 2.9643e-02, 6.9847e-03, 5.1958e-03, 9.5452e-03, 4.6501e-03,
         2.2504e-02, 3.2339e-02, 3.6723e-02, 2.6248e-02, 3.9661e-03, 4.8269e-03,
         1.8571e-02, 5.1515e-01],
        [1.3276e-03, 3.7379e-03, 2.4473e-03, 2.2375e-03, 1.2359e-02, 1.3978e-02,
         1.4800e-03, 1.9873e-03, 1.2890e-02, 1.9276e-02, 8.8006e-04, 2.3332e-03,
         1.9552e-03, 1.9396e-02, 2.3661e-03, 2.3397e-03, 5.3943e-01, 1.8813e-03,
         5.7725e-04, 1.0200e-02, 2.1036e-02, 9.7075e-03, 1.7396e-03, 2.4009e-03,
         1.3566e-02, 2.9847e-01],
        [1.3276e-03, 3.7379e-03, 2.4473e-03, 2.2375e-03, 1.2359e-02, 1.3978e-02,
         1.4800e-03, 1.9873e-03, 1.2890e-02, 1.9276e-02, 8.8006e-04, 2.3332e-03,
         1.9552e-03, 1.9396e-02, 2.3661e-03, 2.3397e-03, 5.3943e-01, 1.8813e-03,
         5.7725e-04, 1.0200e-02, 2.1036e-

 41%|█████████████████▌                         | 20/49 [00:00<00:00, 38.23it/s]

tensor([[8.2632e-03, 7.1272e-03, 6.2170e-03, 5.6340e-03, 7.0765e-03, 2.2375e-02,
         6.9928e-02, 9.0453e-03, 4.7943e-02, 9.4605e-02, 6.0489e-03, 6.7144e-03,
         5.2473e-03, 2.1186e-02, 6.9735e-03, 8.2237e-03, 7.4878e-03, 6.1135e-03,
         2.6722e-02, 1.4777e-02, 3.6605e-02, 1.6028e-02, 7.2664e-03, 5.7844e-03,
         7.0769e-03, 5.3953e-01],
        [1.8491e-03, 1.8431e-03, 2.1466e-03, 2.4876e-03, 3.6507e-03, 2.5933e-02,
         6.1257e-03, 1.9168e-03, 3.1410e-01, 2.4424e-03, 1.9227e-03, 1.8304e-03,
         1.6252e-03, 1.9030e-02, 2.1213e-03, 1.9189e-03, 1.1982e-03, 1.8657e-03,
         3.1869e-02, 3.2990e-03, 4.6832e-02, 4.0283e-03, 1.9236e-03, 1.7463e-03,
         5.3573e-03, 5.1094e-01],
        [8.2632e-03, 7.1272e-03, 6.2170e-03, 5.6340e-03, 7.0765e-03, 2.2375e-02,
         6.9928e-02, 9.0453e-03, 4.7943e-02, 9.4605e-02, 6.0489e-03, 6.7144e-03,
         5.2473e-03, 2.1186e-02, 6.9735e-03, 8.2237e-03, 7.4878e-03, 6.1135e-03,
         2.6722e-02, 1.4777e-02, 3.6605e-

 51%|█████████████████████▉                     | 25/49 [00:00<00:00, 39.75it/s]

y 256 tensor([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 26, 26, 26

 61%|██████████████████████████▎                | 30/49 [00:00<00:00, 40.59it/s]

tensor([[5.6575e-04, 1.6136e-03, 5.9831e-04, 1.0940e-03, 1.4257e-01, 2.0515e-03,
         3.2122e-04, 1.0142e-03, 9.7056e-03, 5.4780e-03, 1.0627e-03, 8.8422e-04,
         1.0711e-03, 3.1912e-03, 9.0627e-04, 6.1462e-04, 9.3840e-03, 5.7874e-04,
         8.7410e-04, 2.1982e-04, 3.1669e-03, 4.8879e-03, 1.1226e-03, 3.2471e-04,
         2.9861e-01, 5.0808e-01],
        [1.4789e-03, 1.8396e-03, 1.5423e-03, 1.7987e-03, 1.0458e-02, 3.3494e-02,
         1.0929e-03, 2.0788e-03, 3.0775e-01, 1.2597e-03, 1.6466e-03, 2.4133e-03,
         1.6283e-03, 2.0349e-02, 1.3331e-03, 2.0544e-03, 1.0165e-03, 2.0086e-03,
         2.6202e-03, 1.6705e-03, 6.0745e-02, 8.5695e-04, 1.5388e-03, 1.2200e-03,
         2.1204e-02, 5.1491e-01],
        [1.7062e-04, 4.0871e-04, 3.3840e-04, 6.2122e-04, 1.9987e-03, 3.7927e-04,
         3.6959e-03, 3.3703e-04, 3.6141e-04, 4.3183e-02, 4.5684e-04, 3.1741e-04,
         4.4944e-04, 5.6429e-04, 1.7094e-04, 4.1922e-04, 1.5836e-03, 1.6998e-04,
         2.4541e-03, 1.6862e-04, 4.4975e-

 71%|██████████████████████████████▋            | 35/49 [00:00<00:00, 41.22it/s]

tensor([[5.2574e-04, 1.4472e-03, 7.1254e-04, 1.6854e-03, 1.1401e-03, 2.5753e-03,
         2.0015e-03, 1.5871e-03, 5.6939e-03, 2.9763e-03, 1.5835e-03, 5.0770e-04,
         4.4444e-04, 1.3807e-03, 7.5696e-04, 9.8685e-04, 1.2485e-04, 5.8327e-04,
         3.1776e-01, 1.5869e-03, 2.3560e-03, 1.2066e-02, 1.5914e-03, 2.1775e-04,
         1.8963e-03, 6.3581e-01],
        [5.2574e-04, 1.4472e-03, 7.1254e-04, 1.6854e-03, 1.1401e-03, 2.5753e-03,
         2.0015e-03, 1.5871e-03, 5.6939e-03, 2.9763e-03, 1.5835e-03, 5.0770e-04,
         4.4444e-04, 1.3807e-03, 7.5696e-04, 9.8685e-04, 1.2485e-04, 5.8327e-04,
         3.1776e-01, 1.5869e-03, 2.3560e-03, 1.2066e-02, 1.5914e-03, 2.1775e-04,
         1.8963e-03, 6.3581e-01],
        [5.2574e-04, 1.4472e-03, 7.1254e-04, 1.6854e-03, 1.1401e-03, 2.5753e-03,
         2.0015e-03, 1.5871e-03, 5.6939e-03, 2.9763e-03, 1.5835e-03, 5.0770e-04,
         4.4444e-04, 1.3807e-03, 7.5696e-04, 9.8685e-04, 1.2485e-04, 5.8327e-04,
         3.1776e-01, 1.5869e-03, 2.3560e-

 82%|███████████████████████████████████        | 40/49 [00:01<00:00, 41.66it/s]

tensor([[4.4039e-03, 5.3214e-03, 4.9699e-03, 6.2438e-03, 9.2907e-03, 4.9968e-02,
         7.0669e-03, 5.7627e-03, 2.1658e-01, 6.3394e-03, 4.9474e-03, 5.7957e-03,
         5.3723e-03, 3.8524e-02, 5.0488e-03, 5.5960e-03, 6.5465e-03, 4.7429e-03,
         1.2201e-02, 7.5150e-03, 1.0808e-01, 7.5450e-03, 5.4210e-03, 5.3665e-03,
         1.0190e-02, 4.5116e-01],
        [1.8660e-03, 1.8563e-03, 1.2364e-03, 1.6874e-03, 3.0975e-03, 3.4604e-02,
         4.1493e-03, 1.9692e-03, 1.2164e-01, 4.6057e-03, 1.9240e-03, 1.8588e-03,
         1.9382e-03, 2.0785e-02, 1.9041e-03, 1.3932e-03, 2.0772e-03, 1.4666e-03,
         9.4249e-03, 4.7036e-03, 2.1259e-01, 5.2661e-03, 1.6249e-03, 1.1713e-03,
         2.6801e-03, 5.5248e-01],
        [5.4396e-03, 8.1235e-03, 6.6139e-03, 8.1995e-03, 1.1573e-02, 6.4631e-02,
         6.6081e-03, 7.7061e-03, 1.7108e-01, 8.1236e-03, 6.3557e-03, 8.4959e-03,
         7.9873e-03, 5.2133e-02, 6.7744e-03, 7.0341e-03, 1.8401e-02, 6.7173e-03,
         7.4230e-03, 1.0869e-02, 1.3441e-

 92%|███████████████████████████████████████▍   | 45/49 [00:01<00:00, 41.97it/s]

tensor([[0.0029, 0.0021, 0.0026, 0.0022, 0.0053, 0.0098, 0.0104, 0.0022, 0.0226,
         0.0117, 0.0016, 0.0016, 0.0010, 0.0096, 0.0027, 0.0025, 0.0013, 0.0015,
         0.3158, 0.0083, 0.0133, 0.0328, 0.0020, 0.0016, 0.0077, 0.5247],
        [0.0029, 0.0021, 0.0026, 0.0022, 0.0053, 0.0098, 0.0104, 0.0022, 0.0226,
         0.0117, 0.0016, 0.0016, 0.0010, 0.0096, 0.0027, 0.0025, 0.0013, 0.0015,
         0.3158, 0.0083, 0.0133, 0.0328, 0.0020, 0.0016, 0.0077, 0.5247],
        [0.0029, 0.0021, 0.0026, 0.0022, 0.0053, 0.0098, 0.0104, 0.0022, 0.0226,
         0.0117, 0.0016, 0.0016, 0.0010, 0.0096, 0.0027, 0.0025, 0.0013, 0.0015,
         0.3158, 0.0083, 0.0133, 0.0328, 0.0020, 0.0016, 0.0077, 0.5247],
        [0.0029, 0.0021, 0.0026, 0.0022, 0.0053, 0.0098, 0.0104, 0.0022, 0.0226,
         0.0117, 0.0016, 0.0016, 0.0010, 0.0096, 0.0027, 0.0025, 0.0013, 0.0015,
         0.3158, 0.0083, 0.0133, 0.0328, 0.0020, 0.0016, 0.0077, 0.5247],
        [0.0052, 0.0055, 0.0054, 0.0060, 0.0091, 0.0420,

100%|███████████████████████████████████████████| 49/49 [00:01<00:00, 36.66it/s]


y_true[closed_mask] [ 4  4  4 ... 24 24 24]
pred_hat_p[closed_mask] [24 24 24 ... 24 24 24]
o_acc_f_q_filtered 0.552133768562262
o_acc_f_q_filtered_masked 0.629056845485339
fscore 0.059811647937413004 f_hq 0.05988709053698215
#############################################################
 Closed Accuracy on Closed Test Data (p / hp): 88.18 / 88.18
 Open Accuracy on Full Test Data (q / hq):     51.27 / 51.27
 Open Accuracy on Extended Test Data (q / hq): 0.00 / 0.00
#############################################################





       prediction              ip  groundtruth            timestamp
0              25    45.33.95.124            4  2024-02-13 20:39:00
1              25   138.68.16.247            4  2024-02-13 21:02:19
2              25    64.227.12.35            4  2024-03-18 13:41:21
3              25     45.33.94.57            4  2024-03-18 14:03:43
4              25  162.142.125.90           27  2024-02-01 07:55:38
...           ...             ...          ...                  ...
12374          25  118.123.105.93           24  2024-02-17 21:05:15
12375          25   103.56.61.144           24  2024-02-29 02:24:27
12376          25   103.56.61.144           24  2024-03-24 20:17:51
12377          25   103.56.61.144           24  2024-03-24 20:18:18
12378          25  118.123.105.85           24  2024-04-06 10:45:59

[12379 rows x 4 columns]
Earliest time: 2024-01-31 16:00:00
Latest time: 2024-04-12 16:00:00
2024-01-31 16:00:00 2024-02-07 16:00:00
2024-02-07 16:00:00 2024-02-14 16:00:00
2024-02-14