In [1]:
from sklearn.model_selection import KFold
import pandas as pd

In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkonggedzu[0m ([33mimucs[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
sweep_config = {
    'method': 'random'
}

In [4]:
metric = {
    'name': 'Grand Mean',
    'goal': 'maximize'
}

sweep_config['metric'] = metric

In [5]:
parameters_dict = {
    'dropout': {
          'values': [0.72]
        },

    'learning_rate': {
        'values': [0.0005]
        },

    'batch_size': {
        'values': [4096]
        },

    'data_augmentation_multiple': {
        'values': [5]
        }
}

In [6]:
sweep_config['parameters'] = parameters_dict

In [7]:
sweep_id = wandb.sweep(sweep_config, project="sub_loc_no_auto_threshold")

Create sweep with ID: 52y1ec6h
Sweep URL: https://wandb.ai/imucs/sub_loc_no_auto_threshold/sweeps/52y1ec6h


In [8]:
import pprint

pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'maximize', 'name': 'Grand Mean'},
 'parameters': {'batch_size': {'values': [4096]},
                'data_augmentation_multiple': {'values': [5]},
                'dropout': {'values': [0.72]},
                'learning_rate': {'values': [0.0005]}}}


In [9]:
feature_pd = pd.read_csv('/home/kongge/projects/new_protT5/data/DPC_T5_578_right.csv')
labels_pd = pd.read_csv("/home/kongge/projects/new_protT5/data/mutil_label_578.csv")

In [10]:
from dataAug.tools import MLDA
from dataAug.all_tools import dataAugSMOTE

In [11]:
smote_multiple = {}

In [12]:
import time
from classify.targeTools import testThresholdFive, Accuracy, countScore
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import torch
from classify.Classify_adjust import ModelClassify
def train_and_val(config=None):
    with wandb.init(config=config):
        config = wandb.config
        kf = KFold(n_splits=10, shuffle=True)
        model_discord = []

        multi_label_samples = labels_pd[(labels_pd.sum(axis=1) >= 2)]
        multi_label_indices = multi_label_samples.index
        multi_features_samples = feature_pd.loc[multi_label_indices]
        ML_G_X, ML_G_y = MLDA(multi_features_samples, multi_label_samples, config.data_augmentation_multiple)

        G_feature, G_label = dataAugSMOTE(feature_pd, labels_pd, config.data_augmentation_multiple, 1424)
        G_feature = pd.concat([G_feature, ML_G_X], axis=0)
        G_label = pd.concat([G_label, ML_G_y], axis=0)
        feature_all = pd.concat([feature_pd, G_feature], axis=0)
        label_all = pd.concat([labels_pd, G_label], axis=0)

        for train_index, test_index in kf.split(feature_all):
            train_data = feature_all.iloc[train_index]
            train_label = label_all.iloc[train_index]

            test_data = feature_all.iloc[test_index]
            test_label = label_all.iloc[test_index]

            datasetTrain = TensorDataset(torch.tensor(train_data.values), torch.tensor(train_label.values))
            batch_size = config.batch_size
            dataloaderTrain = DataLoader(datasetTrain, batch_size=batch_size, shuffle=True)

            datasetTest = TensorDataset(torch.tensor(test_data.values), torch.tensor(test_label.values))

            batch_size = len(datasetTest)
            dataloaderTest = DataLoader(datasetTest, batch_size=batch_size, shuffle=False)

            model = ModelClassify(drop_rate=config.dropout, num_class=5, feature_num=1424)
            criterion = torch.nn.BCELoss()
            optimizer = optim.Adam(model.parameters(), lr=0.001)
            import warnings
            warnings.filterwarnings("ignore")
            GMList = {}
            epochs = 100
            for epoch in range(epochs):
                model.train()
                total_loss = 0.0
                for idx, data in enumerate(dataloaderTrain, 0):
                    inputs, labels = data
                    labels = labels.float()
                    inputs = inputs.float()
                    out = model(inputs)
                    loss = criterion(out, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                avg_loss = total_loss / len(dataloaderTrain)
                threshold = 0.5
                labels_cov = torch.where(out > threshold, torch.tensor(1), torch.tensor(0))
                print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}, ACC: {Accuracy(labels.int(), labels_cov)}")
            model.eval()
            with torch.no_grad():
                for idxTest, dataTest in enumerate(dataloaderTest, 0):
                    inputsTest, targetsFive = dataTest
                    inputsTest = inputsTest.float()
                    test_output = model(inputsTest)
            test_labels_cov = torch.where(test_output > threshold, torch.tensor(1), torch.tensor(0))
            GMScore1 = countScore(targetsFive.int(), test_labels_cov)
            GMList[epoch] = GMScore1
            t = time.time()
            best_key = max(GMList, key=GMList.get)
            best_value = GMList[best_key]
            best_value = [x.item() if isinstance(x, torch.Tensor) else x for x in best_value]
            model_discord.append(best_value)
        model_discord_column_means = [sum(col) / len(col) for col in zip(*model_discord)]
        smote_multiple[int(t)] = model_discord_column_means
        wandb.log({"Grand Mean": model_discord_column_means[0]})

In [13]:
wandb.agent(sweep_id, train_and_val, count=6)

[34m[1mwandb[0m: Agent Starting Run: wpwn5zfv with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


  f"After over-sampling, the number of samples ({n_samples})"
  f"After over-sampling, the number of samples ({n_samples})"
  f"After over-sampling, the number of samples ({n_samples})"


Epoch [1/100], Average Loss: 0.7322, ACC: 0.19581201725348615
Epoch [2/100], Average Loss: 0.7178, ACC: 0.21780519610793597
Epoch [3/100], Average Loss: 0.7061, ACC: 0.2280569766275473
Epoch [4/100], Average Loss: 0.6938, ACC: 0.2505918346875339
Epoch [5/100], Average Loss: 0.6757, ACC: 0.274611295014548
Epoch [6/100], Average Loss: 0.6562, ACC: 0.30462433543986733
Epoch [7/100], Average Loss: 0.6410, ACC: 0.32711405356605755
Epoch [8/100], Average Loss: 0.6231, ACC: 0.3553315277359825
Epoch [9/100], Average Loss: 0.6091, ACC: 0.3748871501655125
Epoch [10/100], Average Loss: 0.5973, ACC: 0.39328418096097834
Epoch [11/100], Average Loss: 0.5809, ACC: 0.40755843113652135
Epoch [12/100], Average Loss: 0.5697, ACC: 0.4255191092386376
Epoch [13/100], Average Loss: 0.5570, ACC: 0.43280168522419277
Epoch [14/100], Average Loss: 0.5439, ACC: 0.4599608787240416
Epoch [15/100], Average Loss: 0.5308, ACC: 0.47066405858159904
Epoch [16/100], Average Loss: 0.5200, ACC: 0.4881432440565705
Epoch [17/

0,1
Grand Mean,▁

0,1
Grand Mean,0.94359


[34m[1mwandb[0m: Agent Starting Run: 0km7e4kd with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016674484033330828, max=1.0…

Epoch [1/100], Average Loss: 0.7282, ACC: 0.19483900090279904
Epoch [2/100], Average Loss: 0.7124, ACC: 0.22297622630153618
Epoch [3/100], Average Loss: 0.7034, ACC: 0.2353245059685038
Epoch [4/100], Average Loss: 0.6829, ACC: 0.2644146855251309
Epoch [5/100], Average Loss: 0.6700, ACC: 0.28168823352392763
Epoch [6/100], Average Loss: 0.6520, ACC: 0.30537666766978
Epoch [7/100], Average Loss: 0.6346, ACC: 0.33389507473167024
Epoch [8/100], Average Loss: 0.6196, ACC: 0.3543835891262924
Epoch [9/100], Average Loss: 0.6054, ACC: 0.3683669375062698
Epoch [10/100], Average Loss: 0.5945, ACC: 0.39234125789948726
Epoch [11/100], Average Loss: 0.5834, ACC: 0.41098906610492386
Epoch [12/100], Average Loss: 0.5713, ACC: 0.4217223392516776
Epoch [13/100], Average Loss: 0.5594, ACC: 0.4296017654729637
Epoch [14/100], Average Loss: 0.5460, ACC: 0.45280369144347055
Epoch [15/100], Average Loss: 0.5346, ACC: 0.4733925168020822
Epoch [16/100], Average Loss: 0.5237, ACC: 0.48477781121476127
Epoch [17/1

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.94925


[34m[1mwandb[0m: Agent Starting Run: q5cf6pqw with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666941833332961, max=1.0)…

Epoch [1/100], Average Loss: 0.7369, ACC: 0.18932691343163835
Epoch [2/100], Average Loss: 0.7225, ACC: 0.20331527735981628
Epoch [3/100], Average Loss: 0.7118, ACC: 0.22021767479185617
Epoch [4/100], Average Loss: 0.7010, ACC: 0.2326913431638096
Epoch [5/100], Average Loss: 0.6846, ACC: 0.2530895776908441
Epoch [6/100], Average Loss: 0.6693, ACC: 0.2835740796469087
Epoch [7/100], Average Loss: 0.6521, ACC: 0.3024726652623166
Epoch [8/100], Average Loss: 0.6383, ACC: 0.3299428227505291
Epoch [9/100], Average Loss: 0.6273, ACC: 0.34370548700973175
Epoch [10/100], Average Loss: 0.6126, ACC: 0.37060888755140875
Epoch [11/100], Average Loss: 0.5989, ACC: 0.3956414886147048
Epoch [12/100], Average Loss: 0.5857, ACC: 0.4083910121376243
Epoch [13/100], Average Loss: 0.5742, ACC: 0.4223894071622004
Epoch [14/100], Average Loss: 0.5627, ACC: 0.44118266626541897
Epoch [15/100], Average Loss: 0.5497, ACC: 0.45566255391714017
Epoch [16/100], Average Loss: 0.5364, ACC: 0.4698214464840963
Epoch [17/

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.94826


[34m[1mwandb[0m: Agent Starting Run: rccsk8k2 with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666945006666841, max=1.0)…

Epoch [1/100], Average Loss: 0.7330, ACC: 0.20359614805898352
Epoch [2/100], Average Loss: 0.7223, ACC: 0.21381783528939818
Epoch [3/100], Average Loss: 0.7090, ACC: 0.22217875413782878
Epoch [4/100], Average Loss: 0.6915, ACC: 0.25467449092186023
Epoch [5/100], Average Loss: 0.6804, ACC: 0.26757448089076474
Epoch [6/100], Average Loss: 0.6627, ACC: 0.2966797070919888
Epoch [7/100], Average Loss: 0.6440, ACC: 0.32738990871702595
Epoch [8/100], Average Loss: 0.6297, ACC: 0.3480138429130317
Epoch [9/100], Average Loss: 0.6148, ACC: 0.36678703982345284
Epoch [10/100], Average Loss: 0.6024, ACC: 0.3874711605978525
Epoch [11/100], Average Loss: 0.5895, ACC: 0.403977329722137
Epoch [12/100], Average Loss: 0.5795, ACC: 0.42992276055772577
Epoch [13/100], Average Loss: 0.5650, ACC: 0.44205035610391924
Epoch [14/100], Average Loss: 0.5520, ACC: 0.45639482395425474
Epoch [15/100], Average Loss: 0.5410, ACC: 0.475494031497639
Epoch [16/100], Average Loss: 0.5277, ACC: 0.49303340355100383
Epoch [1

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.94232


[34m[1mwandb[0m: Agent Starting Run: 3wlpb8fs with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0166709783500058, max=1.0))…

Epoch [1/100], Average Loss: 0.7279, ACC: 0.2026732871902906
Epoch [2/100], Average Loss: 0.7131, ACC: 0.21557327715919475
Epoch [3/100], Average Loss: 0.7001, ACC: 0.23084060587822453
Epoch [4/100], Average Loss: 0.6912, ACC: 0.24584712609088383
Epoch [5/100], Average Loss: 0.6723, ACC: 0.2695405757849365
Epoch [6/100], Average Loss: 0.6577, ACC: 0.29396127996790383
Epoch [7/100], Average Loss: 0.6404, ACC: 0.3230715217173265
Epoch [8/100], Average Loss: 0.6283, ACC: 0.3403601163607204
Epoch [9/100], Average Loss: 0.6124, ACC: 0.36619520513592185
Epoch [10/100], Average Loss: 0.6030, ACC: 0.3719179456314572
Epoch [11/100], Average Loss: 0.5891, ACC: 0.39547095997592485
Epoch [12/100], Average Loss: 0.5787, ACC: 0.4123934196007594
Epoch [13/100], Average Loss: 0.5661, ACC: 0.429802387400941
Epoch [14/100], Average Loss: 0.5563, ACC: 0.4485605376667632
Epoch [15/100], Average Loss: 0.5453, ACC: 0.45494533052462244
Epoch [16/100], Average Loss: 0.5338, ACC: 0.4729160397231388
Epoch [17/1

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.95105


[34m[1mwandb[0m: Agent Starting Run: 8ayk1p1v with config:
[34m[1mwandb[0m: 	batch_size: 4096
[34m[1mwandb[0m: 	data_augmentation_multiple: 5
[34m[1mwandb[0m: 	dropout: 0.72
[34m[1mwandb[0m: 	learning_rate: 0.0005


Epoch [1/100], Average Loss: 0.7355, ACC: 0.18588123181863755
Epoch [2/100], Average Loss: 0.7147, ACC: 0.210999097201325
Epoch [3/100], Average Loss: 0.7034, ACC: 0.22938609690039266
Epoch [4/100], Average Loss: 0.6900, ACC: 0.25333533955261545
Epoch [5/100], Average Loss: 0.6744, ACC: 0.2795215167017786
Epoch [6/100], Average Loss: 0.6574, ACC: 0.303666365733778
Epoch [7/100], Average Loss: 0.6425, ACC: 0.3240696158090109
Epoch [8/100], Average Loss: 0.6245, ACC: 0.35747316681713426
Epoch [9/100], Average Loss: 0.6094, ACC: 0.38621727354799784
Epoch [10/100], Average Loss: 0.5945, ACC: 0.39778312769585533
Epoch [11/100], Average Loss: 0.5819, ACC: 0.4208295716721819
Epoch [12/100], Average Loss: 0.5650, ACC: 0.44544588223492465
Epoch [13/100], Average Loss: 0.5527, ACC: 0.459725147958668
Epoch [14/100], Average Loss: 0.5404, ACC: 0.4770137426020625
Epoch [15/100], Average Loss: 0.5264, ACC: 0.4888203430634927
Epoch [16/100], Average Loss: 0.5160, ACC: 0.5088323803791707
Epoch [17/100

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Grand Mean,▁

0,1
Grand Mean,0.94367


In [14]:
smote_multiple

{1695264989: [0.9435944139957428,
  0.907664249615469,
  0.9407425108034865,
  0.9488658666610718,
  0.9684947252273559,
  0.9522046744823456],
 1695265619: [0.9492540955543518,
  0.9201311067164726,
  0.9470234869015357,
  0.9550993323326111,
  0.9677343249320984,
  0.9562821924686432],
 1695266250: [0.9482617020606995,
  0.9163194902219292,
  0.9459054176127346,
  0.9536245942115784,
  0.9694416344165802,
  0.9560174107551574],
 1695266890: [0.9423164606094361,
  0.9068497766058743,
  0.939609316633707,
  0.9480035603046417,
  0.9662403702735901,
  0.9508794009685516],
 1695267532: [0.9510518074035644,
  0.9211968065626603,
  0.9489555409067606,
  0.9571232199668884,
  0.9696305990219116,
  0.9583529591560364],
 1695268163: [0.9436653673648834,
  0.9092953929539295,
  0.9411537146903001,
  0.9492987215518951,
  0.966559362411499,
  0.952019739151001]}

In [15]:
import json
file = open("/home/kongge/projects/new_protT5/data/dictionary_data_no_auto_threshold.json", "w")
json.dump(smote_multiple, file)
file.close()