In [23]:
from sklearn.model_selection import KFold
import pandas as pd

In [24]:
import wandb
# Log this kernel in to Weights & Biases; the sweep and agent calls below need an authenticated session.
wandb.login()

True

In [25]:
# Root of the W&B sweep definition; trials are drawn with the 'random' search method.
sweep_config = dict(method='random')

In [26]:
# Metric the sweep is configured to maximize; the name must match the key
# passed to wandb.log by the training function.
metric = dict(name='Grand Mean', goal='maximize')
sweep_config.update(metric=metric)

In [27]:
# Search space sampled by the sweep. Discrete options are listed under
# 'values'; the learning rate is drawn uniformly from [min, max].
_dropout_choices = [tenth / 10 for tenth in range(1, 10)]      # 0.1 ... 0.9
_batch_size_choices = [256 << shift for shift in range(5)]     # 256 ... 4096
_augmentation_choices = list(range(1, 8, 2))                   # 1, 3, 5, 7

parameters_dict = dict(
    dropout=dict(values=_dropout_choices),
    learning_rate=dict(distribution='uniform', min=0, max=0.1),
    batch_size=dict(values=_batch_size_choices),
    data_augmentation_multiple=dict(values=_augmentation_choices),
)

In [28]:
# Attach the search space to the sweep definition.
sweep_config.update(parameters=parameters_dict)

In [29]:
# Register the sweep under the given W&B project; the returned id is later handed to wandb.agent.
sweep_id = wandb.sweep(sweep_config, project="pytorch-sweeps-sub_loc_train_only_smote")

Create sweep with ID: pi28mkaa
Sweep URL: https://wandb.ai/imucs/pytorch-sweeps-sub_loc_train_only_smote/sweeps/pi28mkaa


In [30]:
import pprint

# Pretty-print the assembled sweep definition for a visual sanity check.
pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'maximize', 'name': 'Grand Mean'},
 'parameters': {'batch_size': {'values': [256, 512, 1024, 2048, 4096]},
                'data_augmentation_multiple': {'values': [1, 3, 5, 7]},
                'dropout': {'values': [0.1,
                                       0.2,
                                       0.3,
                                       0.4,
                                       0.5,
                                       0.6,
                                       0.7,
                                       0.8,
                                       0.9]},
                'learning_rate': {'distribution': 'uniform',
                                  'max': 0.1,
                                  'min': 0}}}


In [31]:
# Both input tables live in the same directory; keep the location in one place.
DATA_DIR = "/home/kongge/projects/new_protT5/data"

feature_pd = pd.read_csv(f"{DATA_DIR}/DPC_T5_578_right.csv")
labels_pd = pd.read_csv(f"{DATA_DIR}/mutil_label_578.csv")

# The training code slices both frames with the same positional fold indices,
# so a row-count mismatch would silently pair features with the wrong labels
# (or fail much later with an out-of-bounds index). Fail fast instead.
if len(feature_pd) != len(labels_pd):
    raise ValueError(
        f"feature/label row count mismatch: {len(feature_pd)} feature rows "
        f"vs {len(labels_pd)} label rows"
    )

In [32]:
from dataAug.tools import MLDA

In [33]:
from dataAug.all_tools import dataAugSMOTE

In [34]:

# Accumulates one entry per sweep run: run timestamp -> fold-averaged score list.
smote_multiple = dict()

In [35]:
import time
from classify.targeTools import testThresholdFive, Accuracy
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import torch
from Classify_adjust import ModelClassify
def _run_fold(config, train_index, test_index):
    """Train one model on a single CV fold and return its best test-fold scores.

    The training rows are augmented with ``dataAugSMOTE`` (the held-out rows
    are never augmented), a fresh ``ModelClassify`` is trained for 200 epochs,
    and after every epoch ``testThresholdFive`` is evaluated on the held-out
    rows. The score list of the best epoch is returned as plain Python numbers.

    Args:
        config: object exposing ``data_augmentation_multiple``, ``batch_size``,
            ``dropout`` and ``learning_rate`` (here: ``wandb.config``).
        train_index: positional row indices used for training.
        test_index: positional row indices held out for evaluation.

    Returns:
        list: the value returned by ``testThresholdFive`` for the best epoch,
        with any ``torch.Tensor`` entries converted via ``.item()``.
    """
    epochs = 200
    threshold = 0.5

    train_true_data = feature_pd.iloc[train_index]
    train_true_label = labels_pd.iloc[train_index]

    # Synthetic samples are generated from the training split only.
    # NOTE(review): the trailing 1424 presumably is the feature width (it
    # matches feature_num below) - confirm against dataAugSMOTE's signature.
    G_feature, G_label = dataAugSMOTE(train_true_data, train_true_label, config.data_augmentation_multiple, 1424)
    train_feature = pd.concat([train_true_data, G_feature], axis=0)
    train_label = pd.concat([train_true_label, G_label], axis=0)

    test_data = feature_pd.iloc[test_index]
    test_label = labels_pd.iloc[test_index]

    datasetTrain = TensorDataset(torch.tensor(train_feature.values), torch.tensor(train_label.values))
    dataloaderTrain = DataLoader(datasetTrain, batch_size=config.batch_size, shuffle=True)

    # The whole held-out fold is evaluated as a single batch.
    datasetTest = TensorDataset(torch.tensor(test_data.values), torch.tensor(test_label.values))
    dataloaderTest = DataLoader(datasetTest, batch_size=len(datasetTest), shuffle=False)

    model = ModelClassify(drop_rate=config.dropout, num_class=5, feature_num=1424)
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    scores_by_epoch = {}
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for inputs, labels in dataloaderTrain:
            labels = labels.float()
            inputs = inputs.float()
            out = model(inputs)
            loss = criterion(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloaderTrain)

        # The printed ACC is a rough progress indicator only: it is computed
        # on the last training batch of the epoch, not on the full train set.
        labels_cov = torch.where(out > threshold, torch.tensor(1), torch.tensor(0))
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}, ACC: {Accuracy(labels.int(), labels_cov)}")

        # Make sure dropout is disabled while scoring the held-out fold; this
        # is a no-op if testThresholdFive already switches to eval mode, and
        # model.train() at the top of the loop restores training mode.
        model.eval()
        scores_by_epoch[epoch] = testThresholdFive(epoch, model, dataloaderTest, class_num=5)

    # NOTE(review): the best epoch is picked using the held-out fold itself,
    # so the reported scores are optimistic estimates.
    # Score lists compare element-wise, so the first entry decides first;
    # ties resolve to the earliest epoch.
    best_epoch = max(scores_by_epoch, key=scores_by_epoch.get)
    return [x.item() if isinstance(x, torch.Tensor) else x for x in scores_by_epoch[best_epoch]]


def train_and_val(config=None):
    """Run one sweep trial: 10-fold cross-validation with the sampled config.

    For every fold a model is trained and its best per-epoch score list is
    collected (see ``_run_fold``). The lists are averaged column-wise across
    folds; the averages are stored in the module-level ``smote_multiple`` dict
    under the current integer timestamp, and the first average is logged to
    W&B as ``"Grand Mean"`` (the metric the sweep is configured to maximize).

    Args:
        config: optional default configuration forwarded to ``wandb.init``;
            the effective hyperparameters are read back from ``wandb.config``.
    """
    import warnings

    # Silence library warnings once per trial instead of once per fold.
    warnings.filterwarnings("ignore")

    with wandb.init(config=config):
        config = wandb.config

        # A fixed random_state gives every trial the same fold partition, so
        # differences between trials reflect the hyperparameters rather than
        # a luckier or unluckier split.
        kf = KFold(n_splits=10, shuffle=True, random_state=42)

        model_discord = [
            _run_fold(config, train_index, test_index)
            for train_index, test_index in kf.split(feature_pd)
        ]

        model_discord_column_means = [sum(col) / len(col) for col in zip(*model_discord)]
        smote_multiple[int(time.time())] = model_discord_column_means
        wandb.log({"Grand Mean": model_discord_column_means[0]})

In [36]:
# Run up to 20 trials of the sweep in this process; train_and_val is called once per trial.
wandb.agent(sweep_id, train_and_val, count=20)

[34m[1mwandb[0m: Agent Starting Run: 24m4jj5b with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 3
[34m[1mwandb[0m: 	dropout: 0.8
[34m[1mwandb[0m: 	learning_rate: 0.013940153742164264


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669207454348602, max=1.0…

Epoch [1/200], Average Loss: 0.7073, ACC: 0.1845614035087719
epoch:0, bestThreshold:[0.13, 0.83, 0.25, 0.34, 0.62], GM:0.42862066626548767, OAA:0.034482758620689655, ACC:0.38505747126436785, F1:0.51666659116745
Epoch [2/200], Average Loss: 0.6659, ACC: 0.3385964912280702
epoch:1, bestThreshold:[0.01, 0.97, 0.24, 0.23, 0.01], GM:0.39465516805648804, OAA:0.034482758620689655, ACC:0.34482758620689646, F1:0.46896544098854065
Epoch [3/200], Average Loss: 0.6051, ACC: 0.3535087719298245
epoch:2, bestThreshold:[0.5, 0.98, 0.97, 0.05, 0.97], GM:0.5247126817703247, OAA:0.3448275862068966, ACC:0.5057471264367815, F1:0.5660918951034546
Epoch [4/200], Average Loss: 0.5103, ACC: 0.4263157894736842
epoch:3, bestThreshold:[0.01, 0.86, 0.45, 0.98, 0.34], GM:0.46264368295669556, OAA:0.25862068965517243, ACC:0.43534482758620685, F1:0.5057470202445984
Epoch [5/200], Average Loss: 0.4494, ACC: 0.4350877192982456
epoch:4, bestThreshold:[0.01, 0.97, 0.45, 0.43, 0.97], GM:0.5527585744857788, OAA:0.4310344827

0,1
Grand Mean,▁

0,1
Grand Mean,0.86843


[34m[1mwandb[0m: Agent Starting Run: t28bqm4v with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	learning_rate: 0.06482808769484193


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666851753058533, max=1.0)…

Epoch [1/200], Average Loss: 0.9540, ACC: 0.2867298578199054
epoch:0, bestThreshold:[0.01, 0.01, 0.44, 0.03, 0.01], GM:0.28091952204704285, OAA:0.0, ACC:0.20833333333333331, F1:0.3155171573162079
Epoch [2/200], Average Loss: 0.4553, ACC: 0.5213270142180095
epoch:1, bestThreshold:[0.01, 0.84, 0.02, 0.01, 0.55], GM:0.4505172371864319, OAA:0.27586206896551724, ACC:0.4224137931034483, F1:0.4839079678058624
Epoch [3/200], Average Loss: 0.2653, ACC: 0.6611374407582938
epoch:2, bestThreshold:[0.01, 0.29, 0.01, 0.81, 0.98], GM:0.5126436948776245, OAA:0.39655172413793105, ACC:0.5028735632183907, F1:0.5402297973632812
Epoch [4/200], Average Loss: 0.2375, ACC: 0.7630331753554502
epoch:3, bestThreshold:[0.16, 0.94, 0.98, 0.79, 0.01], GM:0.559195339679718, OAA:0.5517241379310345, ACC:0.5574712643678161, F1:0.5603448152542114
Epoch [5/200], Average Loss: 0.2312, ACC: 0.6824644549763034
epoch:4, bestThreshold:[0.02, 0.9, 0.25, 0.98, 0.01], GM:0.6356321573257446, OAA:0.5689655172413793, ACC:0.62643678

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89548


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: p5m6pv5w with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 1
[34m[1mwandb[0m: 	dropout: 0.9
[34m[1mwandb[0m: 	learning_rate: 0.01386928187565113


Epoch [1/200], Average Loss: 0.7378, ACC: 0.19907232704402547
epoch:0, bestThreshold:[0.01, 0.91, 0.64, 0.96, 0.25], GM:0.4568965435028076, OAA:0.017241379310344827, ACC:0.41522988505747127, F1:0.5574710965156555
Epoch [2/200], Average Loss: 0.7071, ACC: 0.21037735849056668
epoch:1, bestThreshold:[0.03, 0.08, 0.81, 0.05, 0.69], GM:0.42965516448020935, OAA:0.034482758620689655, ACC:0.34683908045976997, F1:0.4890803396701813
Epoch [3/200], Average Loss: 0.6951, ACC: 0.19455974842767337
epoch:2, bestThreshold:[0.03, 0.01, 0.8, 0.02, 0.04], GM:0.4082348346710205, OAA:0.017241379310344827, ACC:0.3123563218390804, F1:0.4587027132511139
Epoch [4/200], Average Loss: 0.6685, ACC: 0.2122798742138368
epoch:3, bestThreshold:[0.01, 0.56, 0.94, 0.69, 0.05], GM:0.4744827151298523, OAA:0.05172413793103448, ACC:0.4235632183908045, F1:0.5655171871185303
Epoch [5/200], Average Loss: 0.6368, ACC: 0.19635220125786168
epoch:4, bestThreshold:[0.03, 0.93, 0.87, 0.06, 0.02], GM:0.4289655089378357, OAA:0.051724

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.76741


[34m[1mwandb[0m: Agent Starting Run: wmkiorky with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	learning_rate: 0.06396071724645769


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668705766399703, max=1.0…

Epoch [1/200], Average Loss: 0.7281, ACC: 0.5099388379204893
epoch:0, bestThreshold:[0.96, 0.97, 0.01, 0.24, 0.01], GM:0.6096552014350891, OAA:0.41379310344827586, ACC:0.5882183908045976, F1:0.6534482836723328
Epoch [2/200], Average Loss: 0.3406, ACC: 0.6085626911314985
epoch:1, bestThreshold:[0.01, 0.5, 0.88, 0.89, 0.91], GM:0.7114942669868469, OAA:0.6379310344827587, ACC:0.7068965517241379, F1:0.7298850417137146
Epoch [3/200], Average Loss: 0.2465, ACC: 0.6467889908256881
epoch:2, bestThreshold:[0.2, 0.95, 0.1, 0.93, 0.36], GM:0.8195402026176453, OAA:0.8103448275862069, ACC:0.8189655172413793, F1:0.8218390345573425
Epoch [4/200], Average Loss: 0.2001, ACC: 0.7522935779816514
epoch:3, bestThreshold:[0.26, 0.16, 0.06, 0.14, 0.76], GM:0.8632184267044067, OAA:0.8448275862068966, ACC:0.8620689655172413, F1:0.8678160905838013
Epoch [5/200], Average Loss: 0.1685, ACC: 0.7431192660550459
epoch:4, bestThreshold:[0.02, 0.98, 0.01, 0.1, 0.98], GM:0.8356322050094604, OAA:0.7758620689655172, ACC:

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.90148


[34m[1mwandb[0m: Agent Starting Run: rgux52hh with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.8
[34m[1mwandb[0m: 	learning_rate: 0.09341457606853688


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668759224315485, max=1.0…

Epoch [1/200], Average Loss: 0.7273, ACC: 0.19810335917312685
epoch:0, bestThreshold:[0.01, 0.17, 0.01, 0.01, 0.01], GM:0.3568965494632721, OAA:0.22413793103448276, ACC:0.3448275862068966, F1:0.3879309892654419
Epoch [2/200], Average Loss: 0.8911, ACC: 0.19096124031007738
epoch:1, bestThreshold:[0.01, 0.84, 0.01, 0.01, 0.01], GM:0.3327586054801941, OAA:0.22413793103448276, ACC:0.32471264367816094, F1:0.35919541120529175
Epoch [3/200], Average Loss: 0.6116, ACC: 0.1668475452196382
epoch:2, bestThreshold:[0.92, 0.01, 0.31, 0.01, 0.01], GM:0.4039655327796936, OAA:0.3620689655172414, ACC:0.40086206896551724, F1:0.41494253277778625
Epoch [4/200], Average Loss: 0.4990, ACC: 0.19923514211886306
epoch:3, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.01], GM:0.4333333373069763, OAA:0.39655172413793105, ACC:0.43103448275862066, F1:0.44252872467041016
Epoch [5/200], Average Loss: 0.4805, ACC: 0.3576485788113695
epoch:4, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.01], GM:0.4402298927307129, OAA:0.4310344

0,1
Grand Mean,▁

0,1
Grand Mean,0.88475


[34m[1mwandb[0m: Agent Starting Run: uvb2ny6e with config:
[34m[1mwandb[0m: 	batch_size: 512
[34m[1mwandb[0m: 	data_augmentation_multiple: 3
[34m[1mwandb[0m: 	dropout: 0.6
[34m[1mwandb[0m: 	learning_rate: 0.02194015029855877


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669129456082978, max=1.0…

Epoch [1/200], Average Loss: 0.6299, ACC: 0.4782608695652174
epoch:0, bestThreshold:[0.97, 0.98, 0.9, 0.01, 0.08], GM:0.42500001192092896, OAA:0.15517241379310345, ACC:0.38850574712643676, F1:0.4816091060638428
Epoch [2/200], Average Loss: 0.3250, ACC: 0.5543478260869565
epoch:1, bestThreshold:[0.05, 0.04, 0.1, 0.59, 0.78], GM:0.5559769868850708, OAA:0.46551724137931033, ACC:0.5445402298850575, F1:0.5758620500564575
Epoch [3/200], Average Loss: 0.2060, ACC: 0.7880434782608695
epoch:2, bestThreshold:[0.01, 0.95, 0.39, 0.7, 0.79], GM:0.5804597735404968, OAA:0.5344827586206896, ACC:0.5775862068965517, F1:0.5919540524482727
Epoch [4/200], Average Loss: 0.1709, ACC: 0.7934782608695652
epoch:3, bestThreshold:[0.13, 0.47, 0.04, 0.01, 0.98], GM:0.6494253277778625, OAA:0.603448275862069, ACC:0.646551724137931, F1:0.6609196662902832
Epoch [5/200], Average Loss: 0.1631, ACC: 0.7916666666666666
epoch:4, bestThreshold:[0.08, 0.47, 0.03, 0.02, 0.85], GM:0.6977011561393738, OAA:0.6551724137931034, AC

0,1
Grand Mean,▁

0,1
Grand Mean,0.90101


[34m[1mwandb[0m: Agent Starting Run: 4g4ct8xr with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.6
[34m[1mwandb[0m: 	learning_rate: 0.09171355037311196


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668465174734592, max=1.0…

Epoch [1/200], Average Loss: 0.4491, ACC: 0.4873096446700508
epoch:0, bestThreshold:[0.02, 0.91, 0.07, 0.01, 0.4], GM:0.6574712991714478, OAA:0.5344827586206896, ACC:0.6494252873563219, F1:0.6896551847457886
Epoch [2/200], Average Loss: 0.2735, ACC: 0.6700507614213198
epoch:1, bestThreshold:[0.82, 0.12, 0.01, 0.02, 0.48], GM:0.6982758641242981, OAA:0.5862068965517241, ACC:0.6896551724137931, F1:0.7270115613937378
Epoch [3/200], Average Loss: 0.2052, ACC: 0.7081218274111675
epoch:2, bestThreshold:[0.23, 0.12, 0.1, 0.03, 0.97], GM:0.7954023480415344, OAA:0.7586206896551724, ACC:0.7931034482758621, F1:0.8045977354049683
Epoch [4/200], Average Loss: 0.1594, ACC: 0.8299492385786802
epoch:3, bestThreshold:[0.2, 0.08, 0.38, 0.54, 0.42], GM:0.7540229558944702, OAA:0.6896551724137931, ACC:0.75, F1:0.7701149582862854
Epoch [5/200], Average Loss: 0.1297, ACC: 0.8604060913705583
epoch:4, bestThreshold:[0.36, 0.79, 0.4, 0.04, 0.92], GM:0.794252872467041, OAA:0.7758620689655172, ACC:0.79310344827586

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89161


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 9bivyutu with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 1
[34m[1mwandb[0m: 	dropout: 0.9
[34m[1mwandb[0m: 	learning_rate: 0.02901406039658294


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667133462615311, max=1.0)…

Epoch [1/200], Average Loss: 0.7307, ACC: 0.20822347771500332
epoch:0, bestThreshold:[0.03, 0.94, 0.98, 0.86, 0.75], GM:0.4249425530433655, OAA:0.1724137931034483, ACC:0.3922413793103448, F1:0.47816088795661926
Epoch [2/200], Average Loss: 0.7051, ACC: 0.19022284996861305
epoch:1, bestThreshold:[0.97, 0.28, 0.85, 0.17, 0.11], GM:0.3759769797325134, OAA:0.1724137931034483, ACC:0.3382183908045977, F1:0.41034480929374695
Epoch [3/200], Average Loss: 0.6563, ACC: 0.20723477715003175
epoch:2, bestThreshold:[0.03, 0.02, 0.96, 0.26, 0.22], GM:0.32252877950668335, OAA:0.10344827586206896, ACC:0.28850574712643673, F1:0.36321839690208435
Epoch [4/200], Average Loss: 0.5907, ACC: 0.1898932831136222
epoch:3, bestThreshold:[0.89, 0.09, 0.02, 0.41, 0.11], GM:0.42563214898109436, OAA:0.1724137931034483, ACC:0.3810344827586207, F1:0.46954014897346497
Epoch [5/200], Average Loss: 0.5261, ACC: 0.1284839924670433
epoch:4, bestThreshold:[0.03, 0.75, 0.97, 0.11, 0.01], GM:0.38620689511299133, OAA:0.1034482

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.73147


[34m[1mwandb[0m: Agent Starting Run: wtrpse4v with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	learning_rate: 0.07534863941952263


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670268727466464, max=1.0…

Epoch [1/200], Average Loss: 0.7358, ACC: 0.19546955624355022
epoch:0, bestThreshold:[0.01, 0.47, 0.01, 0.01, 0.01], GM:0.384482741355896, OAA:0.034482758620689655, ACC:0.3563218390804598, F1:0.46839067339897156
Epoch [2/200], Average Loss: 1.0000, ACC: 0.17091847265221796
epoch:1, bestThreshold:[0.01, 0.9, 0.01, 0.01, 0.01], GM:0.24770113825798035, OAA:0.034482758620689655, ACC:0.22413793103448276, F1:0.29597702622413635
Epoch [3/200], Average Loss: 1.2054, ACC: 0.2089267285861714
epoch:2, bestThreshold:[0.03, 0.98, 0.01, 0.01, 0.92], GM:0.3641379475593567, OAA:0.1724137931034483, ACC:0.3333333333333333, F1:0.4011493921279907
Epoch [4/200], Average Loss: 0.5023, ACC: 0.3880908152734776
epoch:3, bestThreshold:[0.01, 0.5, 0.95, 0.91, 0.77], GM:0.42367810010910034, OAA:0.29310344827586204, ACC:0.403735632183908, F1:0.44885051250457764
Epoch [5/200], Average Loss: 0.3875, ACC: 0.4356037151702785
epoch:4, bestThreshold:[0.01, 0.33, 0.02, 0.9, 0.01], GM:0.3968965411186218, OAA:0.24137931034

VBox(children=(Label(value='0.008 MB of 0.414 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.020468…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89451


[34m[1mwandb[0m: Agent Starting Run: adh19fvp with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	learning_rate: 0.09334204080967944


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670222425212464, max=1.0…

Epoch [1/200], Average Loss: 1.2695, ACC: 0.22935779816513768
epoch:0, bestThreshold:[0.01, 0.64, 0.02, 0.29, 0.36], GM:0.25735631585121155, OAA:0.05172413793103448, ACC:0.21408045977011494, F1:0.2896551489830017
Epoch [2/200], Average Loss: 0.5702, ACC: 0.24311926605504589
epoch:1, bestThreshold:[0.02, 0.71, 0.01, 0.01, 0.92], GM:0.5231034159660339, OAA:0.39655172413793105, ACC:0.507183908045977, F1:0.5494252443313599
Epoch [3/200], Average Loss: 0.3641, ACC: 0.5527522935779816
epoch:2, bestThreshold:[0.3, 0.33, 0.94, 0.01, 0.63], GM:0.522988498210907, OAA:0.39655172413793105, ACC:0.5086206896551724, F1:0.5517240762710571
Epoch [4/200], Average Loss: 0.3189, ACC: 0.602446483180428
epoch:3, bestThreshold:[0.01, 0.91, 0.98, 0.01, 0.98], GM:0.5905747413635254, OAA:0.43103448275862066, ACC:0.5632183908045976, F1:0.619540274143219
Epoch [5/200], Average Loss: 0.2882, ACC: 0.5879204892966362
epoch:4, bestThreshold:[0.82, 0.97, 0.98, 0.97, 0.54], GM:0.7396551966667175, OAA:0.6206896551724138

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.90604


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: uv32qh1s with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	learning_rate: 0.08702901630379525


Epoch [1/200], Average Loss: 0.7983, ACC: 0.3333333333333333
epoch:0, bestThreshold:[0.84, 0.01, 0.01, 0.01, 0.01], GM:0.4000000059604645, OAA:0.3448275862068966, ACC:0.39655172413793105, F1:0.4137931168079376
Epoch [2/200], Average Loss: 0.3481, ACC: 0.5073529411764706
epoch:1, bestThreshold:[0.04, 0.24, 0.45, 0.72, 0.04], GM:0.5816091299057007, OAA:0.5172413793103449, ACC:0.5775862068965517, F1:0.5977011919021606
Epoch [3/200], Average Loss: 0.2741, ACC: 0.6102941176470589
epoch:2, bestThreshold:[0.39, 0.76, 0.06, 0.03, 0.5], GM:0.7074712514877319, OAA:0.6724137931034483, ACC:0.7040229885057472, F1:0.7155172824859619
Epoch [4/200], Average Loss: 0.2346, ACC: 0.7352941176470589
epoch:3, bestThreshold:[0.08, 0.44, 0.01, 0.01, 0.84], GM:0.7287356853485107, OAA:0.6551724137931034, ACC:0.7241379310344828, F1:0.7471265196800232
Epoch [5/200], Average Loss: 0.2081, ACC: 0.7401960784313726
epoch:4, bestThreshold:[0.32, 0.09, 0.27, 0.06, 0.97], GM:0.8367816209793091, OAA:0.8275862068965517, A

VBox(children=(Label(value='0.008 MB of 0.413 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.020517…

0,1
Grand Mean,▁

0,1
Grand Mean,0.90133


[34m[1mwandb[0m: Agent Starting Run: qlkm897j with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 3
[34m[1mwandb[0m: 	dropout: 0.6
[34m[1mwandb[0m: 	learning_rate: 0.07468746258241482


Epoch [1/200], Average Loss: 0.7585, ACC: 0.20526315789473681
epoch:0, bestThreshold:[0.96, 0.01, 0.01, 0.01, 0.01], GM:0.38793104887008667, OAA:0.1896551724137931, ACC:0.367816091954023, F1:0.43390804529190063
Epoch [2/200], Average Loss: 0.5322, ACC: 0.44824561403508767
epoch:1, bestThreshold:[0.23, 0.01, 0.01, 0.01, 0.01], GM:0.40626436471939087, OAA:0.22413793103448276, ACC:0.3764367816091953, F1:0.44080454111099243
Epoch [3/200], Average Loss: 0.4425, ACC: 0.3894736842105263
epoch:2, bestThreshold:[0.01, 0.92, 0.84, 0.01, 0.56], GM:0.4316091537475586, OAA:0.39655172413793105, ACC:0.42816091954022995, F1:0.4396551847457886
Epoch [4/200], Average Loss: 0.3385, ACC: 0.5789473684210527
epoch:3, bestThreshold:[0.96, 0.97, 0.64, 0.01, 0.89], GM:0.5264368057250977, OAA:0.5172413793103449, ACC:0.5258620689655172, F1:0.5287356376647949
Epoch [5/200], Average Loss: 0.3685, ACC: 0.5298245614035088
epoch:4, bestThreshold:[0.01, 0.01, 0.71, 0.03, 0.12], GM:0.5798851251602173, OAA:0.51724137931

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.88817


[34m[1mwandb[0m: Agent Starting Run: pjuu0ruq with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.8
[34m[1mwandb[0m: 	learning_rate: 0.024378019351907462


Epoch [1/200], Average Loss: 0.7179, ACC: 0.18143759873617699
epoch:0, bestThreshold:[0.11, 0.86, 0.49, 0.22, 0.96], GM:0.4452873170375824, OAA:0.13793103448275862, ACC:0.4224137931034483, F1:0.5195400714874268
Epoch [2/200], Average Loss: 0.6395, ACC: 0.22922590837282783
epoch:1, bestThreshold:[0.01, 0.76, 0.97, 0.67, 0.87], GM:0.4941379129886627, OAA:0.25862068965517243, ACC:0.47040229885057466, F1:0.5471263527870178
Epoch [3/200], Average Loss: 0.4920, ACC: 0.2985781990521327
epoch:2, bestThreshold:[0.02, 0.02, 0.98, 0.58, 0.04], GM:0.4405747354030609, OAA:0.27586206896551724, ACC:0.41666666666666663, F1:0.472988486289978
Epoch [4/200], Average Loss: 0.3978, ACC: 0.33886255924170616
epoch:3, bestThreshold:[0.05, 0.9, 0.3, 0.01, 0.97], GM:0.5189655423164368, OAA:0.41379310344827586, ACC:0.5086206896551724, F1:0.5431033968925476
Epoch [5/200], Average Loss: 0.3796, ACC: 0.3175355450236967
epoch:4, bestThreshold:[0.04, 0.03, 0.01, 0.02, 0.97], GM:0.6902298927307129, OAA:0.6551724137931

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.88557


[34m[1mwandb[0m: Agent Starting Run: 5uh61txy with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.7
[34m[1mwandb[0m: 	learning_rate: 0.0823677590279994


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670653394733868, max=1.0…

Epoch [1/200], Average Loss: 0.8013, ACC: 0.1846251053074985
epoch:0, bestThreshold:[0.01, 0.6, 0.01, 0.01, 0.01], GM:0.27160918712615967, OAA:0.08620689655172414, ACC:0.2413793103448276, F1:0.3063218295574188
Epoch [2/200], Average Loss: 0.5607, ACC: 0.24880651502386972
epoch:1, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.05], GM:0.5111494064331055, OAA:0.2413793103448276, ACC:0.48994252873563215, F1:0.5758620500564575
Epoch [3/200], Average Loss: 0.4151, ACC: 0.3928671721426566
epoch:2, bestThreshold:[0.02, 0.03, 0.01, 0.01, 0.97], GM:0.6022988557815552, OAA:0.4827586206896552, ACC:0.5948275862068966, F1:0.6321837902069092
Epoch [4/200], Average Loss: 0.3801, ACC: 0.4243190115136197
epoch:3, bestThreshold:[0.01, 0.96, 0.14, 0.16, 0.47], GM:0.615517258644104, OAA:0.5344827586206896, ACC:0.6091954022988506, F1:0.6350574493408203
Epoch [5/200], Average Loss: 0.3454, ACC: 0.4854675652906487
epoch:4, bestThreshold:[0.01, 0.98, 0.44, 0.01, 0.93], GM:0.6720689535140991, OAA:0.5862068965517241,

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89731


[34m[1mwandb[0m: Agent Starting Run: nrmnl57p with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.1
[34m[1mwandb[0m: 	learning_rate: 0.02253264787935998


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667125046563645, max=1.0)…

Epoch [1/200], Average Loss: 0.7417, ACC: 0.1778849144634522
epoch:0, bestThreshold:[0.44, 0.05, 0.01, 0.03, 0.01], GM:0.3829309940338135, OAA:0.0, ACC:0.30804597701149433, F1:0.44425269961357117
Epoch [2/200], Average Loss: 1.0097, ACC: 0.23430274753758565
epoch:1, bestThreshold:[0.98, 0.01, 0.01, 0.98, 0.04], GM:0.3839819133281708, OAA:0.0, ACC:0.2709770114942529, F1:0.4196223020553589
Epoch [3/200], Average Loss: 0.8950, ACC: 0.16741835147744902
epoch:2, bestThreshold:[0.06, 0.98, 0.05, 0.4, 0.3], GM:0.3480459451675415, OAA:0.0, ACC:0.27212643678160925, F1:0.40057459473609924
Epoch [4/200], Average Loss: 0.5612, ACC: 0.40166925868325476
epoch:3, bestThreshold:[0.11, 0.86, 0.02, 0.91, 0.08], GM:0.4670114517211914, OAA:0.1896551724137931, ACC:0.3919540229885057, F1:0.497700959444046
Epoch [5/200], Average Loss: 0.5475, ACC: 0.4339502332814909
epoch:4, bestThreshold:[0.97, 0.98, 0.97, 0.88, 0.96], GM:0.4852873384952545, OAA:0.22413793103448276, ACC:0.4330459770114942, F1:0.529310226440

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.88663


[34m[1mwandb[0m: Agent Starting Run: 7gdv4su0 with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	learning_rate: 0.03788812435493331


Epoch [1/200], Average Loss: 0.7044, ACC: 0.6232227488151659
epoch:0, bestThreshold:[0.69, 0.96, 0.15, 0.01, 0.96], GM:0.40833330154418945, OAA:0.05172413793103448, ACC:0.34770114942528746, F1:0.4741378426551819
Epoch [2/200], Average Loss: 0.2668, ACC: 0.6113744075829384
epoch:1, bestThreshold:[0.04, 0.13, 0.06, 0.89, 0.98], GM:0.7080459594726562, OAA:0.5862068965517241, ACC:0.6954022988505746, F1:0.7356321811676025
Epoch [3/200], Average Loss: 0.1913, ACC: 0.8175355450236966
epoch:2, bestThreshold:[0.01, 0.18, 0.02, 0.08, 0.8], GM:0.7448276281356812, OAA:0.5517241379310345, ACC:0.7327586206896551, F1:0.7931035757064819
Epoch [4/200], Average Loss: 0.1536, ACC: 0.7867298578199052
epoch:3, bestThreshold:[0.35, 0.03, 0.01, 0.01, 0.89], GM:0.7360919713973999, OAA:0.5689655172413793, ACC:0.7270114942528735, F1:0.7781609892845154
Epoch [5/200], Average Loss: 0.1163, ACC: 0.8838862559241706
epoch:4, bestThreshold:[0.04, 0.17, 0.16, 0.01, 0.87], GM:0.8908045887947083, OAA:0.8448275862068966,

VBox(children=(Label(value='0.008 MB of 0.414 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.020378…

0,1
Grand Mean,▁

0,1
Grand Mean,0.90573


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: hb9w29n2 with config:
[34m[1mwandb[0m: 	batch_size: 1024
[34m[1mwandb[0m: 	data_augmentation_multiple: 3
[34m[1mwandb[0m: 	dropout: 0.8
[34m[1mwandb[0m: 	learning_rate: 0.07272088617758642


Epoch [1/200], Average Loss: 0.7115, ACC: 0.22448979591836735
epoch:0, bestThreshold:[0.01, 0.47, 0.01, 0.01, 0.02], GM:0.3285057544708252, OAA:0.13793103448275862, ACC:0.3017241379310345, F1:0.36666661500930786
Epoch [2/200], Average Loss: 0.5512, ACC: 0.23979591836734693
epoch:1, bestThreshold:[0.98, 0.02, 0.01, 0.01, 0.01], GM:0.41034483909606934, OAA:0.3275862068965517, ACC:0.4051724137931034, F1:0.4310344457626343
Epoch [3/200], Average Loss: 0.5045, ACC: 0.2602040816326531
epoch:2, bestThreshold:[0.11, 0.01, 0.47, 0.28, 0.02], GM:0.34999996423721313, OAA:0.2413793103448276, ACC:0.3419540229885058, F1:0.3764367401599884
Epoch [4/200], Average Loss: 0.4818, ACC: 0.23979591836734693
epoch:3, bestThreshold:[0.01, 0.83, 0.08, 0.01, 0.37], GM:0.44942528009414673, OAA:0.29310344827586204, ACC:0.4396551724137931, F1:0.48850566148757935
Epoch [5/200], Average Loss: 0.4511, ACC: 0.28061224489795916
epoch:4, bestThreshold:[0.01, 0.15, 0.07, 0.05, 0.98], GM:0.5109195113182068, OAA:0.31034482

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.8103


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: rhlujew0 with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 3
[34m[1mwandb[0m: 	dropout: 0.7
[34m[1mwandb[0m: 	learning_rate: 0.030951522326570054


Epoch [1/200], Average Loss: 0.7355, ACC: 0.19216651046547853
epoch:0, bestThreshold:[0.25, 0.12, 0.42, 0.06, 0.01], GM:0.327873557806015, OAA:0.034482758620689655, ACC:0.29741379310344834, F1:0.3965516984462738
Epoch [2/200], Average Loss: 0.7255, ACC: 0.20449859418931485
epoch:1, bestThreshold:[0.75, 0.98, 0.98, 0.05, 0.01], GM:0.32862067222595215, OAA:0.05172413793103448, ACC:0.2816091954022989, F1:0.37873557209968567
Epoch [3/200], Average Loss: 0.6699, ACC: 0.2063261480787248
epoch:2, bestThreshold:[0.74, 0.27, 0.18, 0.29, 0.98], GM:0.43390804529190063, OAA:0.29310344827586204, ACC:0.41954022988505746, F1:0.4683907926082611
Epoch [4/200], Average Loss: 0.5239, ACC: 0.4042799125273357
epoch:3, bestThreshold:[0.74, 0.61, 0.97, 0.01, 0.96], GM:0.4585632383823395, OAA:0.3448275862068966, ACC:0.44683908045977005, F1:0.48678162693977356
Epoch [5/200], Average Loss: 0.4188, ACC: 0.45278819119025326
epoch:4, bestThreshold:[0.02, 0.2, 0.98, 0.01, 0.78], GM:0.45160919427871704, OAA:0.362068

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89262


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 2810hq14 with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 7
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	learning_rate: 0.05035680417475822


Epoch [1/200], Average Loss: 0.7720, ACC: 0.1794628751974724
epoch:0, bestThreshold:[0.96, 0.33, 0.96, 0.01, 0.08], GM:0.3408045768737793, OAA:0.0, ACC:0.2887931034482759, F1:0.4080459475517273
Epoch [2/200], Average Loss: 0.7832, ACC: 0.23736176935229072
epoch:1, bestThreshold:[0.01, 0.97, 0.16, 0.44, 0.05], GM:0.41528740525245667, OAA:0.2413793103448276, ACC:0.39080459770114934, F1:0.44999998807907104
Epoch [3/200], Average Loss: 0.3865, ACC: 0.45023696682464454
epoch:2, bestThreshold:[0.33, 0.92, 0.02, 0.29, 0.88], GM:0.3432183861732483, OAA:0.2413793103448276, ACC:0.3261494252873563, F1:0.36264365911483765
Epoch [4/200], Average Loss: 0.3459, ACC: 0.43364928909952605
epoch:3, bestThreshold:[0.37, 0.97, 0.01, 0.98, 0.02], GM:0.43229883909225464, OAA:0.29310344827586204, ACC:0.4094827586206896, F1:0.46034476161003113
Epoch [5/200], Average Loss: 0.3410, ACC: 0.4786729857819905
epoch:4, bestThreshold:[0.97, 0.98, 0.32, 0.97, 0.29], GM:0.4321839213371277, OAA:0.3448275862068966, ACC:0.

0,1
Grand Mean,▁

0,1
Grand Mean,0.90752


[34m[1mwandb[0m: Agent Starting Run: doivlyy9 with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	learning_rate: 0.09149115495446336


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669851867482065, max=1.0…

Epoch [1/200], Average Loss: 0.7405, ACC: 0.196484942886812
epoch:0, bestThreshold:[0.01, 0.01, 0.79, 0.01, 0.01], GM:0.3757471442222595, OAA:0.10344827586206896, ACC:0.3362068965517241, F1:0.43333327770233154
Epoch [2/200], Average Loss: 1.1644, ACC: 0.18430944963655216
epoch:1, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.01], GM:0.25839078426361084, OAA:0.0, ACC:0.19971264367816097, F1:0.3005746901035309
Epoch [3/200], Average Loss: 1.5729, ACC: 0.21770508826583665
epoch:2, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.01], GM:0.41781607270240784, OAA:0.06896551724137931, ACC:0.39080459770114945, F1:0.502873420715332
Epoch [4/200], Average Loss: 0.6562, ACC: 0.3375649013499482
epoch:3, bestThreshold:[0.01, 0.01, 0.01, 0.01, 0.97], GM:0.4833332896232605, OAA:0.3275862068965517, ACC:0.4683908045977011, F1:0.5201148390769958
Epoch [5/200], Average Loss: 0.4118, ACC: 0.4058151609553478
epoch:4, bestThreshold:[0.91, 0.01, 0.87, 0.01, 0.01], GM:0.510919451713562, OAA:0.43103448275862066, ACC:0.505

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.89351


In [37]:
# Display the fold-averaged score lists collected by train_and_val, keyed by integer run timestamp.
smote_multiple

{1695115013: [0.868426913022995,
  0.8271324863883848,
  0.8659860859044162,
  0.8760586857795716,
  0.8941923797130584,
  0.8787649095058441],
 1695118605: [0.8954839587211609,
  0.873653962492438,
  0.8941016333938295,
  0.9073653936386108,
  0.9012855410575866,
  0.901013308763504],
 1695121606: [0.7674086451530456,
  0.6817301875378099,
  0.7583257713248639,
  0.7694242835044861,
  0.8407541871070862,
  0.7868088603019714],
 1695125102: [0.9014761030673981,
  0.8822444041137327,
  0.9001461988304094,
  0.9160465836524964,
  0.9027324020862579,
  0.9062109351158142],
 1695128408: [0.8847529828548432,
  0.8599516031457956,
  0.8830207703165959,
  0.8962946236133575,
  0.893400889635086,
  0.8910970032215119],
 1695132182: [0.9010071754455566,
  0.8753781004234724,
  0.8994933454325471,
  0.9129461586475373,
  0.909724748134613,
  0.9074934542179107],
 1695137335: [0.8916108250617981,
  0.8650635208711435,
  0.8898366606170601,
  0.9022232353687286,
  0.9027979373931885,
  0.898132687

In [38]:
import json

# Persist the per-run averaged scores. The context manager guarantees the
# file is flushed and closed even if serialization raises.
# Note: JSON object keys are always strings, so the integer timestamps are
# written as strings and must be converted back with int() after loading.
with open("/home/kongge/projects/new_protT5/data/dictionary_data_only_smote.json", "w") as file:
    json.dump(smote_multiple, file)