In [7]:
# Pre-train용
import os
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
import torch
from sklearn.model_selection import train_test_split

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-5]  # n번째 레이어의 hidden states를 반환합니다.
        loss = outputs.loss
        return logits, loss, hidden_states

# 데이터 로드 및 전처리
data_A = pd.read_csv("output1.csv")  # data set A 파일명에 맞게 수정
data_B = pd.read_csv("infected.csv")  # data set B 파일명에 맞게 수정
# 모델 저장 경로
model_path = "Pre-trained.pt"

# X_train, Y_train 생성
X_train = []
Y_train = []

for index, row in data_A.iterrows():  # 중복 제거를 하지 않고 원본 데이터 사용
    patient_id = row["ID"]
    patient_info = [str(row[column]) for column in data_A.columns if column != "ID" and column != "DESCRIPTION"]
    symptoms = ", ".join(data_A[data_A["ID"] == patient_id]["DESCRIPTION"].tolist())
    combined_info = ", ".join(patient_info) + ", " + symptoms
    X_train.append(combined_info)
    if patient_id in data_B.values:
        Y_train.append(1)
    else:
        Y_train.append(0)

print("X_train\n", X_train[:10])
print("Y_train\n", Y_train[:10])
        
# BERT 토크나이저 및 모델 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 모델이 이미 저장되어 있는지 확인하고, 저장된 모델이 있으면 불러오고 없으면 새로운 모델 생성
if os.path.exists(model_path):
    # 저장된 모델이 있을 경우 불러오기
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    print("Pre-train model loaded.")
else:
    # 저장된 모델이 없을 경우 새로운 모델 생성
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    print("New model generated.")

# 입력 데이터를 BERT의 입력 형식으로 변환
max_len = 128  # 입력 시퀀스의 최대 길이

input_ids = []
attention_masks = []

for info in X_train:
    encoded_dict = tokenizer.encode_plus(
                        info,                         # 환자 정보 및 증상
                        add_special_tokens = True,    # [CLS], [SEP] 토큰 추가
                        max_length = max_len,         # 최대 길이 지정
                        pad_to_max_length = True,     # 패딩을 추가하여 최대 길이로 맞춤
                        return_attention_mask = True, # 어텐션 마스크 생성
                        return_tensors = 'pt',        # PyTorch 텐서로 반환
                   )
    
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(Y_train)

# 데이터셋 및 데이터로더 생성
dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = 0.8
train_dataset, val_dataset = train_test_split(dataset, test_size=1-train_size, random_state=42)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

# 모델을 GPU로 이동
model.to(device)

# 옵티마이저 및 학습률 설정
# 기본 학습률 : 2e-6
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)

# 에폭 설정
epochs = 3

# 학습 루프
hidden_states_list = []  # 모든 에폭에 대한 hidden state를 저장할 리스트
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2]}
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = outputs[1]  # loss가 outputs의 두 번째 값입니다.
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch + 1}/{epochs}, Batch Loss: {loss.item()}')
        # hidden state를 저장합니다.
        #hidden_states = outputs[2]
        #hidden_states_list.append(hidden_states)

    avg_train_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch + 1}/{epochs}, Average Training Loss: {avg_train_loss}')

# 모든 에폭에 대한 hidden state를 합쳐서 CSV 파일로 저장합니다.
#hidden_states_concat = torch.cat(hidden_states_list, dim=0)
#hidden_states_concat = hidden_states_concat[:, 0, :].cpu().detach().numpy()
#hidden_states_df = pd.DataFrame(hidden_states_concat)
#hidden_states_df.to_csv("hidden_states_all_epochs.csv", index=False)

# 모델 저장
torch.save(model.state_dict(), model_path)

# 모델 평가
model.eval()
val_accuracy = 0
for batch in val_dataloader:
    batch = tuple(t.to(device) for t in batch)
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'labels': batch[2]}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs[0]  # logits가 outputs의 첫 번째 값입니다.
    logits = logits.detach().cpu().numpy()
    label_ids = inputs['labels'].cpu().numpy()
    val_accuracy += (logits.argmax(axis=1) == label_ids).mean().item()

print(f'Validation Accuracy: {val_accuracy / len(val_dataloader)}')


X_train
 ["5/11/1967, nan, 999-83-3739, S99963041, X22481021X, Mrs., Hortense60, O'Hara248, nan, Spinka232, M, white, nonhispanic, F, Northborough  Massachusetts  US, 1040 Turner Knoll, Milford, Massachusetts, Worcester County, nan, 42.13519473, -71.50138606, 1137082.44, 9598.16, 11/12/2019, 1/21/2020, 803d9786-29a1-466c-9365-0205ea0a031c, 36971009.0, Sinusitis (disorder), Body mass index 30+ - obesity (finding), Prediabetes, Chronic sinusitis (disorder)", '11/19/1965, nan, 999-87-9895, S99998149, X88597197X, Mr., Waylon572, Reinger292, nan, nan, M, white, nonhispanic, M, Seekonk  Massachusetts  US, 895 Robel Light, Worcester, Massachusetts, Worcester County, 1604.0, 42.32627029, -71.79641283, 1134089.03, 5790.2, 5/7/2004, nan, ef7b5155-d4be-4b0f-98d4-1313f9b7e90b, 162864005.0, Body mass index 30+ - obesity (finding), Vomiting symptom (finding), Bacterial infectious disease (disorder), Pneumonia (disorder), Hypertension, Chronic sinusitis (disorder), Fever (finding), Hyperlipidemia, Hy

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


New model generated.
True
Epoch 1/3, Batch Loss: 0.6847525238990784
Epoch 1/3, Batch Loss: 0.6722670793533325
Epoch 1/3, Batch Loss: 0.6310189366340637
Epoch 1/3, Batch Loss: 0.6308815479278564
Epoch 1/3, Batch Loss: 0.600823163986206
Epoch 1/3, Batch Loss: 0.6389783024787903
Epoch 1/3, Batch Loss: 0.5263268351554871
Epoch 1/3, Batch Loss: 0.6065617799758911
Epoch 1/3, Batch Loss: 0.5318933725357056
Epoch 1/3, Batch Loss: 0.6635839343070984
Epoch 1/3, Batch Loss: 0.5008132457733154
Epoch 1/3, Batch Loss: 0.48957595229148865
Epoch 1/3, Batch Loss: 0.4432156980037689
Epoch 1/3, Batch Loss: 0.5376770496368408
Epoch 1/3, Batch Loss: 0.5094747543334961
Epoch 1/3, Batch Loss: 0.6337817311286926
Epoch 1/3, Batch Loss: 0.38933855295181274
Epoch 1/3, Batch Loss: 0.4315186142921448
Epoch 1/3, Batch Loss: 0.44252073764801025
Epoch 1/3, Batch Loss: 0.4495561122894287
Epoch 1/3, Batch Loss: 0.3830277621746063
Epoch 1/3, Batch Loss: 0.6138479709625244
Epoch 1/3, Batch Loss: 0.6559330821037292
Epoch 

Epoch 1/3, Batch Loss: 0.2997588813304901
Epoch 1/3, Batch Loss: 0.4131537079811096
Epoch 1/3, Batch Loss: 0.35360416769981384
Epoch 1/3, Batch Loss: 0.3216302692890167
Epoch 1/3, Batch Loss: 0.2872118651866913
Epoch 1/3, Batch Loss: 0.28599852323532104
Epoch 1/3, Batch Loss: 0.44848763942718506
Epoch 1/3, Batch Loss: 0.4808341860771179
Epoch 1/3, Batch Loss: 0.13023139536380768
Epoch 1/3, Batch Loss: 0.26854580640792847
Epoch 1/3, Batch Loss: 0.21243765950202942
Epoch 1/3, Batch Loss: 0.274191290140152
Epoch 1/3, Batch Loss: 0.19439703226089478
Epoch 1/3, Batch Loss: 0.30279526114463806
Epoch 1/3, Batch Loss: 0.47825756669044495
Epoch 1/3, Batch Loss: 0.24582137167453766
Epoch 1/3, Batch Loss: 0.4181859493255615
Epoch 1/3, Batch Loss: 0.4626930058002472
Epoch 1/3, Batch Loss: 0.13820762932300568
Epoch 1/3, Batch Loss: 0.12077299505472183
Epoch 1/3, Batch Loss: 0.1602656990289688
Epoch 1/3, Batch Loss: 0.2628326117992401
Epoch 1/3, Batch Loss: 0.4668749272823334
Epoch 1/3, Batch Loss: 

Epoch 1/3, Batch Loss: 0.5151466131210327
Epoch 1/3, Batch Loss: 0.19452433288097382
Epoch 1/3, Batch Loss: 0.20477958023548126
Epoch 1/3, Batch Loss: 0.19719131290912628
Epoch 1/3, Batch Loss: 0.07704883068799973
Epoch 1/3, Batch Loss: 0.21937695145606995
Epoch 1/3, Batch Loss: 0.2202247530221939
Epoch 1/3, Batch Loss: 0.1811894178390503
Epoch 1/3, Batch Loss: 0.3259986340999603
Epoch 1/3, Batch Loss: 0.24670398235321045
Epoch 1/3, Batch Loss: 0.20213356614112854
Epoch 1/3, Batch Loss: 0.20732854306697845
Epoch 1/3, Batch Loss: 0.18254125118255615
Epoch 1/3, Batch Loss: 0.2918875813484192
Epoch 1/3, Batch Loss: 0.1025620549917221
Epoch 1/3, Batch Loss: 0.10840588808059692
Epoch 1/3, Batch Loss: 0.08697187900543213
Epoch 1/3, Batch Loss: 0.243134006857872
Epoch 1/3, Batch Loss: 0.22642536461353302
Epoch 1/3, Batch Loss: 0.18188968300819397
Epoch 1/3, Batch Loss: 0.09421464800834656
Epoch 1/3, Batch Loss: 0.1958257257938385
Epoch 1/3, Batch Loss: 0.18228699266910553
Epoch 1/3, Batch Los

Epoch 1/3, Batch Loss: 0.3132864832878113
Epoch 1/3, Batch Loss: 0.15930086374282837
Epoch 1/3, Batch Loss: 0.21106639504432678
Epoch 1/3, Batch Loss: 0.1603165715932846
Epoch 1/3, Batch Loss: 0.18348369002342224
Epoch 1/3, Batch Loss: 0.10267043858766556
Epoch 1/3, Batch Loss: 0.3880804181098938
Epoch 1/3, Batch Loss: 0.38576188683509827
Epoch 1/3, Batch Loss: 0.043030720204114914
Epoch 1/3, Batch Loss: 0.11836417019367218
Epoch 1/3, Batch Loss: 0.2222587913274765
Epoch 1/3, Batch Loss: 0.08058943599462509
Epoch 1/3, Batch Loss: 0.07907600700855255
Epoch 1/3, Batch Loss: 0.06680934876203537
Epoch 1/3, Batch Loss: 0.15391474962234497
Epoch 1/3, Batch Loss: 0.1940966248512268
Epoch 1/3, Batch Loss: 0.1617904156446457
Epoch 1/3, Batch Loss: 0.15227405726909637
Epoch 1/3, Batch Loss: 0.21193785965442657
Epoch 1/3, Batch Loss: 0.05788108706474304
Epoch 1/3, Batch Loss: 0.29663437604904175
Epoch 1/3, Batch Loss: 0.3121855854988098
Epoch 1/3, Batch Loss: 0.19365228712558746
Epoch 1/3, Batch 

Epoch 1/3, Batch Loss: 0.3142750561237335
Epoch 1/3, Batch Loss: 0.08850397169589996
Epoch 1/3, Batch Loss: 0.042354293167591095
Epoch 1/3, Batch Loss: 0.1900554895401001
Epoch 1/3, Batch Loss: 0.2751230299472809
Epoch 1/3, Batch Loss: 0.17310628294944763
Epoch 1/3, Batch Loss: 0.4526957869529724
Epoch 1/3, Batch Loss: 0.1371900737285614
Epoch 1/3, Batch Loss: 0.6451597213745117
Epoch 1/3, Batch Loss: 0.15534162521362305
Epoch 1/3, Batch Loss: 0.05948261916637421
Epoch 1/3, Batch Loss: 0.05741928145289421
Epoch 1/3, Batch Loss: 0.1743171364068985
Epoch 1/3, Batch Loss: 0.047810766845941544
Epoch 1/3, Batch Loss: 0.05565940961241722
Epoch 1/3, Batch Loss: 0.06692738831043243
Epoch 1/3, Batch Loss: 0.18230843544006348
Epoch 1/3, Batch Loss: 0.1962898224592209
Epoch 1/3, Batch Loss: 0.08231262862682343
Epoch 1/3, Batch Loss: 0.2097150683403015
Epoch 1/3, Batch Loss: 0.1851988285779953
Epoch 1/3, Batch Loss: 0.19294029474258423
Epoch 1/3, Batch Loss: 0.045983292162418365
Epoch 1/3, Batch L

Epoch 1/3, Batch Loss: 0.06396404653787613
Epoch 1/3, Batch Loss: 0.17407001554965973
Epoch 1/3, Batch Loss: 0.18024064600467682
Epoch 1/3, Batch Loss: 0.14668521285057068
Epoch 1/3, Batch Loss: 0.03605607897043228
Epoch 1/3, Batch Loss: 0.05969712510704994
Epoch 1/3, Batch Loss: 0.08843085169792175
Epoch 1/3, Batch Loss: 0.24220125377178192
Epoch 1/3, Batch Loss: 0.052308663725852966
Epoch 1/3, Batch Loss: 0.2563210725784302
Epoch 1/3, Batch Loss: 0.31764259934425354
Epoch 1/3, Batch Loss: 0.04934941604733467
Epoch 1/3, Batch Loss: 0.145364910364151
Epoch 1/3, Batch Loss: 0.20203407108783722
Epoch 1/3, Batch Loss: 0.05568137392401695
Epoch 1/3, Batch Loss: 0.07000784575939178
Epoch 1/3, Batch Loss: 0.03473726287484169
Epoch 1/3, Batch Loss: 0.08432136476039886
Epoch 1/3, Batch Loss: 0.05869803577661514
Epoch 1/3, Batch Loss: 0.026348445564508438
Epoch 1/3, Batch Loss: 0.1771460324525833
Epoch 1/3, Batch Loss: 0.047879658639431
Epoch 1/3, Batch Loss: 0.2023952156305313
Epoch 1/3, Batch

Epoch 1/3, Batch Loss: 0.05945700779557228
Epoch 1/3, Batch Loss: 0.18180575966835022
Epoch 1/3, Batch Loss: 0.2680610418319702
Epoch 1/3, Batch Loss: 0.19455693662166595
Epoch 1/3, Batch Loss: 0.17297576367855072
Epoch 1/3, Batch Loss: 0.13890619575977325
Epoch 1/3, Batch Loss: 0.05680760368704796
Epoch 1/3, Batch Loss: 0.06670676916837692
Epoch 1/3, Batch Loss: 0.15402205288410187
Epoch 1/3, Batch Loss: 0.19614946842193604
Epoch 1/3, Batch Loss: 0.24469983577728271
Epoch 1/3, Batch Loss: 0.2708611488342285
Epoch 1/3, Batch Loss: 0.2236962914466858
Epoch 1/3, Batch Loss: 0.11806375533342361
Epoch 1/3, Batch Loss: 0.16313757002353668
Epoch 1/3, Batch Loss: 0.15721410512924194
Epoch 1/3, Batch Loss: 0.2169536054134369
Epoch 1/3, Batch Loss: 0.1053403690457344
Epoch 1/3, Batch Loss: 0.17479336261749268
Epoch 1/3, Batch Loss: 0.20081643760204315
Epoch 1/3, Batch Loss: 0.2858700454235077
Epoch 1/3, Batch Loss: 0.13888117671012878
Epoch 1/3, Batch Loss: 0.1818419098854065
Epoch 1/3, Batch L

Epoch 1/3, Batch Loss: 0.2511742413043976
Epoch 1/3, Batch Loss: 0.18991993367671967
Epoch 1/3, Batch Loss: 0.09461219608783722
Epoch 1/3, Batch Loss: 0.21556925773620605
Epoch 1/3, Batch Loss: 0.14709670841693878
Epoch 1/3, Batch Loss: 0.18512916564941406
Epoch 1/3, Batch Loss: 0.11960967630147934
Epoch 1/3, Batch Loss: 0.16115976870059967
Epoch 1/3, Batch Loss: 0.07247855514287949
Epoch 1/3, Batch Loss: 0.10911396145820618
Epoch 1/3, Batch Loss: 0.0985770896077156
Epoch 1/3, Batch Loss: 0.0795668363571167
Epoch 1/3, Batch Loss: 0.4153788983821869
Epoch 1/3, Batch Loss: 0.1319395899772644
Epoch 1/3, Batch Loss: 0.11372227966785431
Epoch 1/3, Batch Loss: 0.03989388421177864
Epoch 1/3, Batch Loss: 0.2147664576768875
Epoch 1/3, Batch Loss: 0.17338162660598755
Epoch 1/3, Batch Loss: 0.1698719561100006
Epoch 1/3, Batch Loss: 0.20324721932411194
Epoch 1/3, Batch Loss: 0.3053728938102722
Epoch 1/3, Batch Loss: 0.16746240854263306
Epoch 1/3, Batch Loss: 0.1853906214237213
Epoch 1/3, Batch Los

Epoch 1/3, Batch Loss: 0.18773837387561798
Epoch 1/3, Batch Loss: 0.15790770947933197
Epoch 1/3, Batch Loss: 0.29192981123924255
Epoch 1/3, Batch Loss: 0.41845715045928955
Epoch 1/3, Batch Loss: 0.12420664727687836
Epoch 1/3, Batch Loss: 0.3355643153190613
Epoch 1/3, Batch Loss: 0.26860710978507996
Epoch 1/3, Batch Loss: 0.1450396329164505
Epoch 1/3, Batch Loss: 0.054723143577575684
Epoch 1/3, Batch Loss: 0.06839268654584885
Epoch 1/3, Batch Loss: 0.5103667974472046
Epoch 1/3, Batch Loss: 0.08298058062791824
Epoch 1/3, Batch Loss: 0.07147777080535889
Epoch 1/3, Batch Loss: 0.18100646138191223
Epoch 1/3, Batch Loss: 0.05693037807941437
Epoch 1/3, Batch Loss: 0.17537479102611542
Epoch 1/3, Batch Loss: 0.27657845616340637
Epoch 1/3, Batch Loss: 0.34013739228248596
Epoch 1/3, Batch Loss: 0.2819605767726898
Epoch 1/3, Batch Loss: 0.033602070063352585
Epoch 1/3, Batch Loss: 0.1809990406036377
Epoch 1/3, Batch Loss: 0.3005358576774597
Epoch 1/3, Batch Loss: 0.3229104280471802
Epoch 1/3, Batch

Epoch 1/3, Batch Loss: 0.19173882901668549
Epoch 1/3, Batch Loss: 0.18809978663921356
Epoch 1/3, Batch Loss: 0.1539665162563324
Epoch 1/3, Batch Loss: 0.0983661636710167
Epoch 1/3, Batch Loss: 0.2921711802482605
Epoch 1/3, Batch Loss: 0.10908826440572739
Epoch 1/3, Batch Loss: 0.16060270369052887
Epoch 1/3, Batch Loss: 0.3482244610786438
Epoch 1/3, Batch Loss: 0.21117591857910156
Epoch 1/3, Batch Loss: 0.16498416662216187
Epoch 1/3, Batch Loss: 0.2926059663295746
Epoch 1/3, Batch Loss: 0.2788853347301483
Epoch 1/3, Batch Loss: 0.2740228772163391
Epoch 1/3, Batch Loss: 0.06861525028944016
Epoch 1/3, Batch Loss: 0.1886432021856308
Epoch 1/3, Batch Loss: 0.2007370889186859
Epoch 1/3, Batch Loss: 0.17294582724571228
Epoch 1/3, Batch Loss: 0.23831519484519958
Epoch 1/3, Batch Loss: 0.16331973671913147
Epoch 1/3, Batch Loss: 0.07222267240285873
Epoch 1/3, Batch Loss: 0.09453059732913971
Epoch 1/3, Batch Loss: 0.17076431214809418
Epoch 1/3, Batch Loss: 0.04184817522764206
Epoch 1/3, Batch Los

Epoch 1/3, Batch Loss: 0.15193425118923187
Epoch 1/3, Batch Loss: 0.30525127053260803
Epoch 1/3, Batch Loss: 0.057541921734809875
Epoch 1/3, Batch Loss: 0.2823999226093292
Epoch 1/3, Batch Loss: 0.09757263213396072
Epoch 1/3, Batch Loss: 0.0876365602016449
Epoch 1/3, Batch Loss: 0.09413307905197144
Epoch 1/3, Batch Loss: 0.2716563045978546
Epoch 1/3, Batch Loss: 0.41881993412971497
Epoch 1/3, Batch Loss: 0.09213100373744965
Epoch 1/3, Batch Loss: 0.18147487938404083
Epoch 1/3, Batch Loss: 0.21745814383029938
Epoch 1/3, Batch Loss: 0.07547979056835175
Epoch 1/3, Batch Loss: 0.1877678632736206
Epoch 1/3, Batch Loss: 0.042236000299453735
Epoch 1/3, Batch Loss: 0.22969374060630798
Epoch 1/3, Batch Loss: 0.06080682575702667
Epoch 1/3, Batch Loss: 0.15336015820503235
Epoch 1/3, Batch Loss: 0.05379848554730415
Epoch 1/3, Batch Loss: 0.2367718517780304
Epoch 1/3, Batch Loss: 0.303947776556015
Epoch 1/3, Batch Loss: 0.1918179988861084
Epoch 1/3, Batch Loss: 0.05100443959236145
Epoch 1/3, Batch 

Epoch 1/3, Batch Loss: 0.05848256126046181
Epoch 1/3, Batch Loss: 0.18694061040878296
Epoch 1/3, Batch Loss: 0.3859572410583496
Epoch 1/3, Batch Loss: 0.07411744445562363
Epoch 1/3, Batch Loss: 0.10388851910829544
Epoch 1/3, Batch Loss: 0.07537150382995605
Epoch 1/3, Batch Loss: 0.5533769130706787
Epoch 1/3, Batch Loss: 0.19393351674079895
Epoch 1/3, Batch Loss: 0.19025667011737823
Epoch 1/3, Batch Loss: 0.25644606351852417
Epoch 1/3, Batch Loss: 0.1436336189508438
Epoch 1/3, Batch Loss: 0.2781325876712799
Epoch 1/3, Batch Loss: 0.30121952295303345
Epoch 1/3, Batch Loss: 0.19696220755577087
Epoch 1/3, Batch Loss: 0.36385005712509155
Epoch 1/3, Batch Loss: 0.1986340582370758
Epoch 1/3, Batch Loss: 0.06502650678157806
Epoch 1/3, Batch Loss: 0.27368125319480896
Epoch 1/3, Batch Loss: 0.19743506610393524
Epoch 1/3, Batch Loss: 0.17112977802753448
Epoch 1/3, Batch Loss: 0.11869880557060242
Epoch 1/3, Batch Loss: 0.09050542861223221
Epoch 1/3, Batch Loss: 0.13959144055843353
Epoch 1/3, Batch

Epoch 1/3, Batch Loss: 0.08466293662786484
Epoch 1/3, Batch Loss: 0.07650694996118546
Epoch 1/3, Batch Loss: 0.17746371030807495
Epoch 1/3, Batch Loss: 0.06403017789125443
Epoch 1/3, Batch Loss: 0.27871835231781006
Epoch 1/3, Batch Loss: 0.37334656715393066
Epoch 1/3, Batch Loss: 0.05940478295087814
Epoch 1/3, Batch Loss: 0.26547524333000183
Epoch 1/3, Batch Loss: 0.151270791888237
Epoch 1/3, Batch Loss: 0.3246210813522339
Epoch 1/3, Batch Loss: 0.5055840015411377
Epoch 1/3, Batch Loss: 0.018080592155456543
Epoch 1/3, Batch Loss: 0.09708542376756668
Epoch 1/3, Batch Loss: 0.1839064210653305
Epoch 1/3, Batch Loss: 0.30328598618507385
Epoch 1/3, Batch Loss: 0.05919354408979416
Epoch 1/3, Batch Loss: 0.4801120460033417
Epoch 1/3, Batch Loss: 0.076096311211586
Epoch 1/3, Batch Loss: 0.12294584512710571
Epoch 1/3, Batch Loss: 0.05501044914126396
Epoch 1/3, Batch Loss: 0.10311960428953171
Epoch 1/3, Batch Loss: 0.2757534384727478
Epoch 1/3, Batch Loss: 0.08247718960046768
Epoch 1/3, Batch Lo

Epoch 1/3, Batch Loss: 0.0831168070435524
Epoch 1/3, Batch Loss: 0.08844761550426483
Epoch 1/3, Batch Loss: 0.38941317796707153
Epoch 1/3, Batch Loss: 0.24821893870830536
Epoch 1/3, Batch Loss: 0.04137640818953514
Epoch 1/3, Batch Loss: 0.13325199484825134
Epoch 1/3, Batch Loss: 0.12933988869190216
Epoch 1/3, Batch Loss: 0.06931310892105103
Epoch 1/3, Batch Loss: 0.12330003082752228
Epoch 1/3, Batch Loss: 0.041495293378829956
Epoch 1/3, Batch Loss: 0.27878984808921814
Epoch 1/3, Batch Loss: 0.5898592472076416
Epoch 1/3, Batch Loss: 0.1610592156648636
Epoch 1/3, Batch Loss: 0.3235996663570404
Epoch 1/3, Batch Loss: 0.040566280484199524
Epoch 1/3, Batch Loss: 0.04764987528324127
Epoch 1/3, Batch Loss: 0.13975068926811218
Epoch 1/3, Batch Loss: 0.1439003050327301
Epoch 1/3, Batch Loss: 0.057597748935222626
Epoch 1/3, Batch Loss: 0.3872491717338562
Epoch 1/3, Batch Loss: 0.04987698793411255
Epoch 1/3, Batch Loss: 0.17236186563968658
Epoch 1/3, Batch Loss: 0.0821191817522049
Epoch 1/3, Batc

Epoch 1/3, Batch Loss: 0.20103253424167633
Epoch 1/3, Batch Loss: 0.3224886357784271
Epoch 1/3, Batch Loss: 0.13920485973358154
Epoch 1/3, Batch Loss: 0.0966327041387558
Epoch 1/3, Batch Loss: 0.42907580733299255
Epoch 1/3, Batch Loss: 0.1817830353975296
Epoch 1/3, Batch Loss: 0.26058197021484375
Epoch 1/3, Batch Loss: 0.17612382769584656
Epoch 1/3, Batch Loss: 0.30636170506477356
Epoch 1/3, Batch Loss: 0.09050743281841278
Epoch 1/3, Batch Loss: 0.09357503801584244
Epoch 1/3, Batch Loss: 0.4128403663635254
Epoch 1/3, Batch Loss: 0.09189380705356598
Epoch 1/3, Batch Loss: 0.16699737310409546
Epoch 1/3, Batch Loss: 0.17722484469413757
Epoch 1/3, Batch Loss: 0.09971687942743301
Epoch 1/3, Batch Loss: 0.2153252512216568
Epoch 1/3, Batch Loss: 0.10803846269845963
Epoch 1/3, Batch Loss: 0.218511700630188
Epoch 1/3, Batch Loss: 0.05463016778230667
Epoch 1/3, Batch Loss: 0.2058076560497284
Epoch 1/3, Batch Loss: 0.08766695111989975
Epoch 1/3, Batch Loss: 0.10780646651983261
Epoch 1/3, Batch Lo

Epoch 1/3, Batch Loss: 0.26985904574394226
Epoch 1/3, Batch Loss: 0.3207997977733612
Epoch 1/3, Batch Loss: 0.17809170484542847
Epoch 1/3, Batch Loss: 0.09319472312927246
Epoch 1/3, Batch Loss: 0.08759085088968277
Epoch 1/3, Batch Loss: 0.05053038150072098
Epoch 1/3, Batch Loss: 0.1112159937620163
Epoch 1/3, Batch Loss: 0.0397026501595974
Epoch 1/3, Batch Loss: 0.1897367537021637
Epoch 1/3, Batch Loss: 0.17781700193881989
Epoch 1/3, Batch Loss: 0.219825878739357
Epoch 1/3, Batch Loss: 0.17468279600143433
Epoch 1/3, Batch Loss: 0.05199482664465904
Epoch 1/3, Batch Loss: 0.38720160722732544
Epoch 1/3, Batch Loss: 0.14501388370990753
Epoch 1/3, Batch Loss: 0.20172293484210968
Epoch 1/3, Batch Loss: 0.261429101228714
Epoch 1/3, Batch Loss: 0.053209107369184494
Epoch 1/3, Batch Loss: 0.5853615999221802
Epoch 1/3, Batch Loss: 0.27978062629699707
Epoch 1/3, Batch Loss: 0.0598919652402401
Epoch 1/3, Batch Loss: 0.4399104714393616
Epoch 1/3, Batch Loss: 0.04831802845001221
Epoch 1/3, Batch Loss

Epoch 1/3, Batch Loss: 0.18057593703269958
Epoch 1/3, Batch Loss: 0.058431219309568405
Epoch 1/3, Batch Loss: 0.14661267399787903
Epoch 1/3, Batch Loss: 0.26015588641166687
Epoch 1/3, Batch Loss: 0.2713329792022705
Epoch 1/3, Batch Loss: 0.02010628767311573
Epoch 1/3, Batch Loss: 0.14995673298835754
Epoch 1/3, Batch Loss: 0.23013107478618622
Epoch 1/3, Batch Loss: 0.10984209924936295
Epoch 1/3, Batch Loss: 0.23487915098667145
Epoch 1/3, Batch Loss: 0.20923174917697906
Epoch 1/3, Batch Loss: 0.038635026663541794
Epoch 1/3, Batch Loss: 0.10253053158521652
Epoch 1/3, Batch Loss: 0.26026299595832825
Epoch 1/3, Batch Loss: 0.19085389375686646
Epoch 1/3, Batch Loss: 0.0859127789735794
Epoch 1/3, Batch Loss: 0.13755768537521362
Epoch 1/3, Batch Loss: 0.21053242683410645
Epoch 1/3, Batch Loss: 0.08619531244039536
Epoch 1/3, Batch Loss: 0.10309810936450958
Epoch 1/3, Batch Loss: 0.36473047733306885
Epoch 1/3, Batch Loss: 0.43702593445777893
Epoch 1/3, Batch Loss: 0.17940373718738556
Epoch 1/3, 

Epoch 1/3, Batch Loss: 0.08106784522533417
Epoch 1/3, Batch Loss: 0.179088294506073
Epoch 1/3, Batch Loss: 0.3036099374294281
Epoch 1/3, Batch Loss: 0.16630443930625916
Epoch 1/3, Batch Loss: 0.6386078000068665
Epoch 1/3, Batch Loss: 0.06646113842725754
Epoch 1/3, Batch Loss: 0.20728904008865356
Epoch 1/3, Batch Loss: 0.05534375086426735
Epoch 1/3, Batch Loss: 0.4620824158191681
Epoch 1/3, Batch Loss: 0.15315496921539307
Epoch 1/3, Batch Loss: 0.31380099058151245
Epoch 1/3, Batch Loss: 0.21816520392894745
Epoch 1/3, Batch Loss: 0.05160021781921387
Epoch 1/3, Batch Loss: 0.1432933509349823
Epoch 1/3, Batch Loss: 0.5357885956764221
Epoch 1/3, Batch Loss: 0.12593701481819153
Epoch 1/3, Batch Loss: 0.23644250631332397
Epoch 1/3, Batch Loss: 0.45806679129600525
Epoch 1/3, Batch Loss: 0.2885107696056366
Epoch 1/3, Batch Loss: 0.18549449741840363
Epoch 1/3, Batch Loss: 0.5455300211906433
Epoch 1/3, Batch Loss: 0.32635363936424255
Epoch 1/3, Batch Loss: 0.22068078815937042
Epoch 1/3, Batch Los

Epoch 2/3, Batch Loss: 0.15170788764953613
Epoch 2/3, Batch Loss: 0.10967724770307541
Epoch 2/3, Batch Loss: 0.2485322654247284
Epoch 2/3, Batch Loss: 0.04481646046042442
Epoch 2/3, Batch Loss: 0.2972519099712372
Epoch 2/3, Batch Loss: 0.0846194252371788
Epoch 2/3, Batch Loss: 0.09786079078912735
Epoch 2/3, Batch Loss: 0.1530381590127945
Epoch 2/3, Batch Loss: 0.23187321424484253
Epoch 2/3, Batch Loss: 0.2000698447227478
Epoch 2/3, Batch Loss: 0.1834535300731659
Epoch 2/3, Batch Loss: 0.3041535019874573
Epoch 2/3, Batch Loss: 0.5935243368148804
Epoch 2/3, Batch Loss: 0.050836581736803055
Epoch 2/3, Batch Loss: 0.04656049609184265
Epoch 2/3, Batch Loss: 0.4061460793018341
Epoch 2/3, Batch Loss: 0.278975248336792
Epoch 2/3, Batch Loss: 0.11445492506027222
Epoch 2/3, Batch Loss: 0.2720101475715637
Epoch 2/3, Batch Loss: 0.10016107559204102
Epoch 2/3, Batch Loss: 0.22551530599594116
Epoch 2/3, Batch Loss: 0.15869256854057312
Epoch 2/3, Batch Loss: 0.0646262988448143
Epoch 2/3, Batch Loss: 

Epoch 2/3, Batch Loss: 0.06306705623865128
Epoch 2/3, Batch Loss: 0.1913781613111496
Epoch 2/3, Batch Loss: 0.03518502414226532
Epoch 2/3, Batch Loss: 0.03686470165848732
Epoch 2/3, Batch Loss: 0.15572302043437958
Epoch 2/3, Batch Loss: 0.014070551842451096
Epoch 2/3, Batch Loss: 0.04007376357913017
Epoch 2/3, Batch Loss: 0.17004109919071198
Epoch 2/3, Batch Loss: 0.11937101185321808
Epoch 2/3, Batch Loss: 0.1912258118391037
Epoch 2/3, Batch Loss: 0.297877699136734
Epoch 2/3, Batch Loss: 0.18270879983901978
Epoch 2/3, Batch Loss: 0.15345333516597748
Epoch 2/3, Batch Loss: 0.19950038194656372
Epoch 2/3, Batch Loss: 0.04060225188732147
Epoch 2/3, Batch Loss: 0.23105724155902863
Epoch 2/3, Batch Loss: 0.1061522364616394
Epoch 2/3, Batch Loss: 0.07017781585454941
Epoch 2/3, Batch Loss: 0.04251433163881302
Epoch 2/3, Batch Loss: 0.019462643191218376
Epoch 2/3, Batch Loss: 0.1017010360956192
Epoch 2/3, Batch Loss: 0.06128996983170509
Epoch 2/3, Batch Loss: 0.15990710258483887
Epoch 2/3, Batc

Epoch 2/3, Batch Loss: 0.10495423525571823
Epoch 2/3, Batch Loss: 0.42285919189453125
Epoch 2/3, Batch Loss: 0.05426153540611267
Epoch 2/3, Batch Loss: 0.10522540658712387
Epoch 2/3, Batch Loss: 0.09481445699930191
Epoch 2/3, Batch Loss: 0.7919904589653015
Epoch 2/3, Batch Loss: 0.4320421814918518
Epoch 2/3, Batch Loss: 0.039275024086236954
Epoch 2/3, Batch Loss: 0.3064698278903961
Epoch 2/3, Batch Loss: 0.2996165454387665
Epoch 2/3, Batch Loss: 0.03438685089349747
Epoch 2/3, Batch Loss: 0.07293163985013962
Epoch 2/3, Batch Loss: 0.2088630497455597
Epoch 2/3, Batch Loss: 0.2263355553150177
Epoch 2/3, Batch Loss: 0.05247066169977188
Epoch 2/3, Batch Loss: 0.03900463506579399
Epoch 2/3, Batch Loss: 0.031300801783800125
Epoch 2/3, Batch Loss: 0.11420908570289612
Epoch 2/3, Batch Loss: 0.29858049750328064
Epoch 2/3, Batch Loss: 0.32751572132110596
Epoch 2/3, Batch Loss: 0.19237275421619415
Epoch 2/3, Batch Loss: 0.338958740234375
Epoch 2/3, Batch Loss: 0.2798667848110199
Epoch 2/3, Batch L

Epoch 2/3, Batch Loss: 0.16468806564807892
Epoch 2/3, Batch Loss: 0.05846388638019562
Epoch 2/3, Batch Loss: 0.049055688083171844
Epoch 2/3, Batch Loss: 0.18046461045742035
Epoch 2/3, Batch Loss: 0.023377446457743645
Epoch 2/3, Batch Loss: 0.09666815400123596
Epoch 2/3, Batch Loss: 0.016341928392648697
Epoch 2/3, Batch Loss: 0.12363485246896744
Epoch 2/3, Batch Loss: 0.2837703227996826
Epoch 2/3, Batch Loss: 0.1121334433555603
Epoch 2/3, Batch Loss: 0.10100214183330536
Epoch 2/3, Batch Loss: 0.06703472882509232
Epoch 2/3, Batch Loss: 0.08073120564222336
Epoch 2/3, Batch Loss: 0.10881932079792023
Epoch 2/3, Batch Loss: 0.16287650167942047
Epoch 2/3, Batch Loss: 0.0397428423166275
Epoch 2/3, Batch Loss: 0.04289700463414192
Epoch 2/3, Batch Loss: 0.27320781350135803
Epoch 2/3, Batch Loss: 0.3599857687950134
Epoch 2/3, Batch Loss: 0.050059277564287186
Epoch 2/3, Batch Loss: 0.3368014097213745
Epoch 2/3, Batch Loss: 0.12382368743419647
Epoch 2/3, Batch Loss: 0.29893091320991516
Epoch 2/3, B

Epoch 2/3, Batch Loss: 0.22590070962905884
Epoch 2/3, Batch Loss: 0.1885070651769638
Epoch 2/3, Batch Loss: 0.0937766581773758
Epoch 2/3, Batch Loss: 0.12570108473300934
Epoch 2/3, Batch Loss: 0.04448791965842247
Epoch 2/3, Batch Loss: 0.21368902921676636
Epoch 2/3, Batch Loss: 0.2585199177265167
Epoch 2/3, Batch Loss: 0.06812117993831635
Epoch 2/3, Batch Loss: 0.09769314527511597
Epoch 2/3, Batch Loss: 0.6578382253646851
Epoch 2/3, Batch Loss: 0.18087758123874664
Epoch 2/3, Batch Loss: 0.1077093854546547
Epoch 2/3, Batch Loss: 0.06848012655973434
Epoch 2/3, Batch Loss: 0.41351085901260376
Epoch 2/3, Batch Loss: 0.1353464424610138
Epoch 2/3, Batch Loss: 0.056638024747371674
Epoch 2/3, Batch Loss: 0.42899852991104126
Epoch 2/3, Batch Loss: 0.04040967673063278
Epoch 2/3, Batch Loss: 0.06016048043966293
Epoch 2/3, Batch Loss: 0.17621032893657684
Epoch 2/3, Batch Loss: 0.37370064854621887
Epoch 2/3, Batch Loss: 0.07618667930364609
Epoch 2/3, Batch Loss: 0.24997900426387787
Epoch 2/3, Batch

Epoch 2/3, Batch Loss: 0.20560482144355774
Epoch 2/3, Batch Loss: 0.17980830371379852
Epoch 2/3, Batch Loss: 0.11041583120822906
Epoch 2/3, Batch Loss: 0.17848964035511017
Epoch 2/3, Batch Loss: 0.2553348243236542
Epoch 2/3, Batch Loss: 0.00408919295296073
Epoch 2/3, Batch Loss: 0.19669121503829956
Epoch 2/3, Batch Loss: 0.09623923897743225
Epoch 2/3, Batch Loss: 0.14951670169830322
Epoch 2/3, Batch Loss: 0.2521929144859314
Epoch 2/3, Batch Loss: 0.1588743031024933
Epoch 2/3, Batch Loss: 0.19666968286037445
Epoch 2/3, Batch Loss: 0.12403680384159088
Epoch 2/3, Batch Loss: 0.2655746340751648
Epoch 2/3, Batch Loss: 0.17067432403564453
Epoch 2/3, Batch Loss: 0.06360753625631332
Epoch 2/3, Batch Loss: 0.10242605209350586
Epoch 2/3, Batch Loss: 0.09004247188568115
Epoch 2/3, Batch Loss: 0.06444140523672104
Epoch 2/3, Batch Loss: 0.13591571152210236
Epoch 2/3, Batch Loss: 0.17749592661857605
Epoch 2/3, Batch Loss: 0.28248482942581177
Epoch 2/3, Batch Loss: 0.2514379620552063
Epoch 2/3, Batch

Epoch 2/3, Batch Loss: 0.304195374250412
Epoch 2/3, Batch Loss: 0.24082791805267334
Epoch 2/3, Batch Loss: 0.07631288468837738
Epoch 2/3, Batch Loss: 0.17408418655395508
Epoch 2/3, Batch Loss: 0.05873258411884308
Epoch 2/3, Batch Loss: 0.15685094892978668
Epoch 2/3, Batch Loss: 0.28682175278663635
Epoch 2/3, Batch Loss: 0.07207061350345612
Epoch 2/3, Batch Loss: 0.15392020344734192
Epoch 2/3, Batch Loss: 0.2835361659526825
Epoch 2/3, Batch Loss: 0.1720443069934845
Epoch 2/3, Batch Loss: 0.1084703579545021
Epoch 2/3, Batch Loss: 0.02210453897714615
Epoch 2/3, Batch Loss: 0.18775470554828644
Epoch 2/3, Batch Loss: 0.2563861310482025
Epoch 2/3, Batch Loss: 0.24787001311779022
Epoch 2/3, Batch Loss: 0.17760244011878967
Epoch 2/3, Batch Loss: 0.04484786465764046
Epoch 2/3, Batch Loss: 0.174873948097229
Epoch 2/3, Batch Loss: 0.07220510393381119
Epoch 2/3, Batch Loss: 0.18019895255565643
Epoch 2/3, Batch Loss: 0.45909857749938965
Epoch 2/3, Batch Loss: 0.29242831468582153
Epoch 2/3, Batch Lo

Epoch 2/3, Batch Loss: 0.474759578704834
Epoch 2/3, Batch Loss: 0.056741971522569656
Epoch 2/3, Batch Loss: 0.04118184372782707
Epoch 2/3, Batch Loss: 0.04528211057186127
Epoch 2/3, Batch Loss: 0.047498613595962524
Epoch 2/3, Batch Loss: 0.24528935551643372
Epoch 2/3, Batch Loss: 0.03998760133981705
Epoch 2/3, Batch Loss: 0.07953803241252899
Epoch 2/3, Batch Loss: 0.21285070478916168
Epoch 2/3, Batch Loss: 0.052559684962034225
Epoch 2/3, Batch Loss: 0.20956310629844666
Epoch 2/3, Batch Loss: 0.0775454193353653
Epoch 2/3, Batch Loss: 0.1699552983045578
Epoch 2/3, Batch Loss: 0.09928712248802185
Epoch 2/3, Batch Loss: 0.11764535307884216
Epoch 2/3, Batch Loss: 0.04476311802864075
Epoch 2/3, Batch Loss: 0.033154506236314774
Epoch 2/3, Batch Loss: 0.08908672630786896
Epoch 2/3, Batch Loss: 0.046713605523109436
Epoch 2/3, Batch Loss: 0.03306204080581665
Epoch 2/3, Batch Loss: 0.0464898981153965
Epoch 2/3, Batch Loss: 0.06493259221315384
Epoch 2/3, Batch Loss: 0.18110647797584534
Epoch 2/3, 

Epoch 2/3, Batch Loss: 0.07769845426082611
Epoch 2/3, Batch Loss: 0.06291283667087555
Epoch 2/3, Batch Loss: 0.1532992124557495
Epoch 2/3, Batch Loss: 0.2798765003681183
Epoch 2/3, Batch Loss: 0.4857304096221924
Epoch 2/3, Batch Loss: 0.12412754446268082
Epoch 2/3, Batch Loss: 0.05566595494747162
Epoch 2/3, Batch Loss: 0.3034651577472687
Epoch 2/3, Batch Loss: 0.053880512714385986
Epoch 2/3, Batch Loss: 0.14221113920211792
Epoch 2/3, Batch Loss: 0.05770060792565346
Epoch 2/3, Batch Loss: 0.11859751492738724
Epoch 2/3, Batch Loss: 0.4826183021068573
Epoch 2/3, Batch Loss: 0.14312878251075745
Epoch 2/3, Batch Loss: 0.09428227692842484
Epoch 2/3, Batch Loss: 0.11772482097148895
Epoch 2/3, Batch Loss: 0.07899418473243713
Epoch 2/3, Batch Loss: 0.6187300086021423
Epoch 2/3, Batch Loss: 0.09394074231386185
Epoch 2/3, Batch Loss: 0.10203035920858383
Epoch 2/3, Batch Loss: 0.25602537393569946
Epoch 2/3, Batch Loss: 0.1442563533782959
Epoch 2/3, Batch Loss: 0.07382815331220627
Epoch 2/3, Batch 

Epoch 2/3, Batch Loss: 0.30627143383026123
Epoch 2/3, Batch Loss: 0.4398517906665802
Epoch 2/3, Batch Loss: 0.023812955245375633
Epoch 2/3, Batch Loss: 0.08995430171489716
Epoch 2/3, Batch Loss: 0.1655624508857727
Epoch 2/3, Batch Loss: 0.1516711562871933
Epoch 2/3, Batch Loss: 0.31313556432724
Epoch 2/3, Batch Loss: 0.30484089255332947
Epoch 2/3, Batch Loss: 0.32287922501564026
Epoch 2/3, Batch Loss: 0.1654157191514969
Epoch 2/3, Batch Loss: 0.12307648360729218
Epoch 2/3, Batch Loss: 0.40892910957336426
Epoch 2/3, Batch Loss: 0.1604733020067215
Epoch 2/3, Batch Loss: 0.16534508764743805
Epoch 2/3, Batch Loss: 0.1591043770313263
Epoch 2/3, Batch Loss: 0.05513443052768707
Epoch 2/3, Batch Loss: 0.24444632232189178
Epoch 2/3, Batch Loss: 0.3017437756061554
Epoch 2/3, Batch Loss: 0.11354345083236694
Epoch 2/3, Batch Loss: 0.10847166180610657
Epoch 2/3, Batch Loss: 0.20110616087913513
Epoch 2/3, Batch Loss: 0.0669284462928772
Epoch 2/3, Batch Loss: 0.1533869355916977
Epoch 2/3, Batch Loss:

Epoch 2/3, Batch Loss: 0.3385812044143677
Epoch 2/3, Batch Loss: 0.052560072392225266
Epoch 2/3, Batch Loss: 0.03565671294927597
Epoch 2/3, Batch Loss: 0.05293760448694229
Epoch 2/3, Batch Loss: 0.47722166776657104
Epoch 2/3, Batch Loss: 0.17082449793815613
Epoch 2/3, Batch Loss: 0.13025248050689697
Epoch 2/3, Batch Loss: 0.37522226572036743
Epoch 2/3, Batch Loss: 0.20302188396453857
Epoch 2/3, Batch Loss: 0.4125438928604126
Epoch 2/3, Batch Loss: 0.29640328884124756
Epoch 2/3, Batch Loss: 0.1581559181213379
Epoch 2/3, Batch Loss: 0.03303808718919754
Epoch 2/3, Batch Loss: 0.266242653131485
Epoch 2/3, Batch Loss: 0.27555981278419495
Epoch 2/3, Batch Loss: 0.10131380707025528
Epoch 2/3, Batch Loss: 0.35093361139297485
Epoch 2/3, Batch Loss: 0.30322885513305664
Epoch 2/3, Batch Loss: 0.07338225096464157
Epoch 2/3, Batch Loss: 0.05540885776281357
Epoch 2/3, Batch Loss: 0.10808545351028442
Epoch 2/3, Batch Loss: 0.37738820910453796
Epoch 2/3, Batch Loss: 0.2898384630680084
Epoch 2/3, Batch

Epoch 2/3, Batch Loss: 0.02988392859697342
Epoch 2/3, Batch Loss: 0.297627329826355
Epoch 2/3, Batch Loss: 0.08514600992202759
Epoch 2/3, Batch Loss: 0.08701221644878387
Epoch 2/3, Batch Loss: 0.44807669520378113
Epoch 2/3, Batch Loss: 0.6626300811767578
Epoch 2/3, Batch Loss: 0.4688701629638672
Epoch 2/3, Batch Loss: 0.07076099514961243
Epoch 2/3, Batch Loss: 0.05312005430459976
Epoch 2/3, Batch Loss: 0.15890561044216156
Epoch 2/3, Batch Loss: 0.1453966200351715
Epoch 2/3, Batch Loss: 0.18249265849590302
Epoch 2/3, Batch Loss: 0.056903861463069916
Epoch 2/3, Batch Loss: 0.2907930910587311
Epoch 2/3, Batch Loss: 0.28918731212615967
Epoch 2/3, Batch Loss: 0.08189064264297485
Epoch 2/3, Batch Loss: 0.14959914982318878
Epoch 2/3, Batch Loss: 0.08212154358625412
Epoch 2/3, Batch Loss: 0.14648641645908356
Epoch 2/3, Batch Loss: 0.18435116112232208
Epoch 2/3, Batch Loss: 0.33176493644714355
Epoch 2/3, Batch Loss: 0.163613423705101
Epoch 2/3, Batch Loss: 0.27684423327445984
Epoch 2/3, Batch L

Epoch 2/3, Batch Loss: 0.17709708213806152
Epoch 2/3, Batch Loss: 0.13882270455360413
Epoch 2/3, Batch Loss: 0.2567654848098755
Epoch 2/3, Batch Loss: 0.4028589725494385
Epoch 2/3, Batch Loss: 0.1209813728928566
Epoch 2/3, Batch Loss: 0.11976467072963715
Epoch 2/3, Batch Loss: 0.06639190018177032
Epoch 2/3, Batch Loss: 0.03531878814101219
Epoch 2/3, Batch Loss: 0.06606345623731613
Epoch 2/3, Batch Loss: 0.31078168749809265
Epoch 2/3, Batch Loss: 0.30834585428237915
Epoch 2/3, Batch Loss: 0.02593693509697914
Epoch 2/3, Batch Loss: 0.11943098902702332
Epoch 2/3, Batch Loss: 0.1600235104560852
Epoch 2/3, Batch Loss: 0.05451131612062454
Epoch 2/3, Batch Loss: 0.2988545298576355
Epoch 2/3, Batch Loss: 0.18457871675491333
Epoch 2/3, Batch Loss: 0.24685785174369812
Epoch 2/3, Batch Loss: 0.5530730485916138
Epoch 2/3, Batch Loss: 0.17268143594264984
Epoch 2/3, Batch Loss: 0.21880097687244415
Epoch 2/3, Batch Loss: 0.01325040590018034
Epoch 2/3, Batch Loss: 0.5223776698112488
Epoch 2/3, Batch L

Epoch 2/3, Batch Loss: 0.15046779811382294
Epoch 2/3, Batch Loss: 0.18420106172561646
Epoch 2/3, Batch Loss: 0.19155317544937134
Epoch 2/3, Batch Loss: 0.05376341566443443
Epoch 2/3, Batch Loss: 0.12004931271076202
Epoch 2/3, Batch Loss: 0.05210278555750847
Epoch 2/3, Batch Loss: 0.034913334995508194
Epoch 2/3, Batch Loss: 0.1434101015329361
Epoch 2/3, Batch Loss: 0.1430560201406479
Epoch 2/3, Batch Loss: 0.05785084143280983
Epoch 2/3, Batch Loss: 0.3093417286872864
Epoch 2/3, Batch Loss: 0.06182044371962547
Epoch 2/3, Batch Loss: 0.042943187057971954
Epoch 2/3, Batch Loss: 0.1481555700302124
Epoch 2/3, Batch Loss: 0.09483938664197922
Epoch 2/3, Batch Loss: 0.0744890496134758
Epoch 2/3, Batch Loss: 0.02380528673529625
Epoch 2/3, Batch Loss: 0.04481734707951546
Epoch 2/3, Batch Loss: 0.2998954653739929
Epoch 2/3, Batch Loss: 0.1723528951406479
Epoch 2/3, Batch Loss: 0.2461404800415039
Epoch 2/3, Batch Loss: 0.3602875769138336
Epoch 2/3, Batch Loss: 0.17056904733181
Epoch 2/3, Batch Loss

Epoch 2/3, Batch Loss: 0.17372742295265198
Epoch 2/3, Batch Loss: 0.04050040245056152
Epoch 2/3, Batch Loss: 0.14918671548366547
Epoch 2/3, Batch Loss: 0.1623055338859558
Epoch 2/3, Batch Loss: 0.02248849719762802
Epoch 2/3, Batch Loss: 0.42504483461380005
Epoch 2/3, Batch Loss: 0.3137037456035614
Epoch 2/3, Batch Loss: 0.14382289350032806
Epoch 2/3, Batch Loss: 0.15949320793151855
Epoch 2/3, Batch Loss: 0.29408368468284607
Epoch 2/3, Batch Loss: 0.04304211959242821
Epoch 2/3, Batch Loss: 0.07735413312911987
Epoch 2/3, Batch Loss: 0.04931046813726425
Epoch 2/3, Batch Loss: 0.20517319440841675
Epoch 2/3, Batch Loss: 0.19351790845394135
Epoch 2/3, Batch Loss: 0.04838815703988075
Epoch 2/3, Batch Loss: 0.15253423154354095
Epoch 2/3, Batch Loss: 0.20186872780323029
Epoch 2/3, Batch Loss: 0.399615079164505
Epoch 2/3, Batch Loss: 0.537885308265686
Epoch 2/3, Batch Loss: 0.11506473273038864
Epoch 2/3, Batch Loss: 0.1690731942653656
Epoch 2/3, Batch Loss: 0.19235576689243317
Epoch 2/3, Batch L

Epoch 2/3, Batch Loss: 0.03129727765917778
Epoch 2/3, Batch Loss: 0.08042221516370773
Epoch 2/3, Batch Loss: 0.15045176446437836
Epoch 2/3, Batch Loss: 0.035968780517578125
Epoch 2/3, Batch Loss: 0.07839883863925934
Epoch 2/3, Batch Loss: 0.04462886229157448
Epoch 2/3, Batch Loss: 0.08397980779409409
Epoch 2/3, Batch Loss: 0.03331993520259857
Epoch 2/3, Batch Loss: 0.18735578656196594
Epoch 2/3, Batch Loss: 0.16110427677631378
Epoch 2/3, Batch Loss: 0.20398204028606415
Epoch 2/3, Batch Loss: 0.021881163120269775
Epoch 2/3, Batch Loss: 0.05033956840634346
Epoch 2/3, Batch Loss: 0.04471801966428757
Epoch 2/3, Batch Loss: 0.06212136521935463
Epoch 2/3, Batch Loss: 0.04582249000668526
Epoch 2/3, Batch Loss: 0.03772648051381111
Epoch 2/3, Batch Loss: 0.3655898869037628
Epoch 2/3, Batch Loss: 0.05882348120212555
Epoch 2/3, Batch Loss: 0.28057798743247986
Epoch 2/3, Batch Loss: 0.08712543547153473
Epoch 2/3, Batch Loss: 0.17414109408855438
Epoch 2/3, Batch Loss: 0.14194336533546448
Epoch 2/3,

Epoch 2/3, Batch Loss: 0.08902480453252792
Epoch 2/3, Batch Loss: 0.28400570154190063
Epoch 2/3, Batch Loss: 0.06417611241340637
Epoch 2/3, Batch Loss: 0.05067978426814079
Epoch 2/3, Batch Loss: 0.18753308057785034
Epoch 2/3, Batch Loss: 0.03138257935643196
Epoch 2/3, Batch Loss: 0.06138652190566063
Epoch 2/3, Batch Loss: 0.11374478787183762
Epoch 2/3, Batch Loss: 0.21775008738040924
Epoch 2/3, Batch Loss: 0.34207677841186523
Epoch 2/3, Batch Loss: 0.28458693623542786
Epoch 2/3, Batch Loss: 0.17347073554992676
Epoch 2/3, Batch Loss: 0.09473871439695358
Epoch 2/3, Batch Loss: 0.34213411808013916
Epoch 2/3, Batch Loss: 0.18095554411411285
Epoch 2/3, Batch Loss: 0.3492598831653595
Epoch 2/3, Batch Loss: 0.11438090354204178
Epoch 2/3, Batch Loss: 0.1556345522403717
Epoch 2/3, Batch Loss: 0.2033568173646927
Epoch 2/3, Batch Loss: 0.12982331216335297
Epoch 2/3, Batch Loss: 0.027725758031010628
Epoch 2/3, Batch Loss: 0.1711421012878418
Epoch 2/3, Batch Loss: 0.09495089948177338
Epoch 2/3, Bat

Epoch 2/3, Batch Loss: 0.2016679346561432
Epoch 2/3, Batch Loss: 0.3969368040561676
Epoch 2/3, Batch Loss: 0.09873189777135849
Epoch 2/3, Batch Loss: 0.04881379380822182
Epoch 2/3, Batch Loss: 0.14480215311050415
Epoch 2/3, Batch Loss: 0.10529342293739319
Epoch 2/3, Batch Loss: 0.17295758426189423
Epoch 2/3, Batch Loss: 0.15344677865505219
Epoch 2/3, Batch Loss: 0.1470508873462677
Epoch 2/3, Batch Loss: 0.24883343279361725
Epoch 2/3, Batch Loss: 0.16335557401180267
Epoch 2/3, Batch Loss: 0.18991927802562714
Epoch 2/3, Batch Loss: 0.12440308183431625
Epoch 2/3, Batch Loss: 0.3067205846309662
Epoch 2/3, Batch Loss: 0.19323301315307617
Epoch 2/3, Batch Loss: 0.13371768593788147
Epoch 2/3, Batch Loss: 0.20502535998821259
Epoch 2/3, Batch Loss: 0.4147745966911316
Epoch 2/3, Batch Loss: 0.13837820291519165
Epoch 2/3, Batch Loss: 0.15766045451164246
Epoch 2/3, Batch Loss: 0.1979013979434967
Epoch 2/3, Batch Loss: 0.12314676493406296
Epoch 2/3, Batch Loss: 0.19940423965454102
Epoch 2/3, Batch 

Epoch 3/3, Batch Loss: 0.02497708424925804
Epoch 3/3, Batch Loss: 0.19386854767799377
Epoch 3/3, Batch Loss: 0.31485211849212646
Epoch 3/3, Batch Loss: 0.3072972297668457
Epoch 3/3, Batch Loss: 0.25330406427383423
Epoch 3/3, Batch Loss: 0.03280198574066162
Epoch 3/3, Batch Loss: 0.1304149180650711
Epoch 3/3, Batch Loss: 0.04345587268471718
Epoch 3/3, Batch Loss: 0.25755774974823
Epoch 3/3, Batch Loss: 0.16868138313293457
Epoch 3/3, Batch Loss: 0.2517872154712677
Epoch 3/3, Batch Loss: 0.24352872371673584
Epoch 3/3, Batch Loss: 0.23362821340560913
Epoch 3/3, Batch Loss: 0.20010924339294434
Epoch 3/3, Batch Loss: 0.170835942029953
Epoch 3/3, Batch Loss: 0.09625420719385147
Epoch 3/3, Batch Loss: 0.22589154541492462
Epoch 3/3, Batch Loss: 0.18172062933444977
Epoch 3/3, Batch Loss: 0.13749855756759644
Epoch 3/3, Batch Loss: 0.2578474283218384
Epoch 3/3, Batch Loss: 0.10221821069717407
Epoch 3/3, Batch Loss: 0.060637328773736954
Epoch 3/3, Batch Loss: 0.26772409677505493
Epoch 3/3, Batch Lo

Epoch 3/3, Batch Loss: 0.1622689813375473
Epoch 3/3, Batch Loss: 0.322811096906662
Epoch 3/3, Batch Loss: 0.15403223037719727
Epoch 3/3, Batch Loss: 0.1026073694229126
Epoch 3/3, Batch Loss: 0.06794936209917068
Epoch 3/3, Batch Loss: 0.11845482140779495
Epoch 3/3, Batch Loss: 0.15066729485988617
Epoch 3/3, Batch Loss: 0.15188679099082947
Epoch 3/3, Batch Loss: 0.19953718781471252
Epoch 3/3, Batch Loss: 0.08802354335784912
Epoch 3/3, Batch Loss: 0.2710370719432831
Epoch 3/3, Batch Loss: 0.13184009492397308
Epoch 3/3, Batch Loss: 0.09056641906499863
Epoch 3/3, Batch Loss: 0.0600300133228302
Epoch 3/3, Batch Loss: 0.11071682721376419
Epoch 3/3, Batch Loss: 0.1855199784040451
Epoch 3/3, Batch Loss: 0.4054906368255615
Epoch 3/3, Batch Loss: 0.05004260689020157
Epoch 3/3, Batch Loss: 0.2888527512550354
Epoch 3/3, Batch Loss: 0.17252132296562195
Epoch 3/3, Batch Loss: 0.03115544468164444
Epoch 3/3, Batch Loss: 0.15198197960853577
Epoch 3/3, Batch Loss: 0.15700320899486542
Epoch 3/3, Batch Los

Epoch 3/3, Batch Loss: 0.06344339996576309
Epoch 3/3, Batch Loss: 0.21942362189292908
Epoch 3/3, Batch Loss: 0.18878620862960815
Epoch 3/3, Batch Loss: 0.07720118761062622
Epoch 3/3, Batch Loss: 0.054874684661626816
Epoch 3/3, Batch Loss: 0.0308622308075428
Epoch 3/3, Batch Loss: 0.05422217771410942
Epoch 3/3, Batch Loss: 0.043958764523267746
Epoch 3/3, Batch Loss: 0.2538749575614929
Epoch 3/3, Batch Loss: 0.14816416800022125
Epoch 3/3, Batch Loss: 0.19658342003822327
Epoch 3/3, Batch Loss: 0.26769256591796875
Epoch 3/3, Batch Loss: 0.11043262481689453
Epoch 3/3, Batch Loss: 0.3688202202320099
Epoch 3/3, Batch Loss: 0.15854740142822266
Epoch 3/3, Batch Loss: 0.2093367725610733
Epoch 3/3, Batch Loss: 0.0025331429205834866
Epoch 3/3, Batch Loss: 0.15002931654453278
Epoch 3/3, Batch Loss: 0.16774336993694305
Epoch 3/3, Batch Loss: 0.22119678556919098
Epoch 3/3, Batch Loss: 0.1318565309047699
Epoch 3/3, Batch Loss: 0.27656200528144836
Epoch 3/3, Batch Loss: 0.03846113756299019
Epoch 3/3, B

Epoch 3/3, Batch Loss: 0.06191627308726311
Epoch 3/3, Batch Loss: 0.13773232698440552
Epoch 3/3, Batch Loss: 0.1030721515417099
Epoch 3/3, Batch Loss: 0.3512876033782959
Epoch 3/3, Batch Loss: 0.16122585535049438
Epoch 3/3, Batch Loss: 0.12822243571281433
Epoch 3/3, Batch Loss: 0.05500463396310806
Epoch 3/3, Batch Loss: 0.1402231603860855
Epoch 3/3, Batch Loss: 0.056836120784282684
Epoch 3/3, Batch Loss: 0.05010709539055824
Epoch 3/3, Batch Loss: 0.21465523540973663
Epoch 3/3, Batch Loss: 0.17037837207317352
Epoch 3/3, Batch Loss: 0.0930279865860939
Epoch 3/3, Batch Loss: 0.17677435278892517
Epoch 3/3, Batch Loss: 0.1945526897907257
Epoch 3/3, Batch Loss: 0.08374932408332825
Epoch 3/3, Batch Loss: 0.06652145087718964
Epoch 3/3, Batch Loss: 0.1453102082014084
Epoch 3/3, Batch Loss: 0.15830419957637787
Epoch 3/3, Batch Loss: 0.11464329808950424
Epoch 3/3, Batch Loss: 0.1489279419183731
Epoch 3/3, Batch Loss: 0.23408342897891998
Epoch 3/3, Batch Loss: 0.05627048760652542
Epoch 3/3, Batch 

Epoch 3/3, Batch Loss: 0.039381228387355804
Epoch 3/3, Batch Loss: 0.4391709566116333
Epoch 3/3, Batch Loss: 0.2994425296783447
Epoch 3/3, Batch Loss: 0.02891438454389572
Epoch 3/3, Batch Loss: 0.05694285035133362
Epoch 3/3, Batch Loss: 0.3542550802230835
Epoch 3/3, Batch Loss: 0.07346002012491226
Epoch 3/3, Batch Loss: 0.11382041126489639
Epoch 3/3, Batch Loss: 0.4353012442588806
Epoch 3/3, Batch Loss: 0.07786086946725845
Epoch 3/3, Batch Loss: 0.15995164215564728
Epoch 3/3, Batch Loss: 0.20639494061470032
Epoch 3/3, Batch Loss: 0.04287930577993393
Epoch 3/3, Batch Loss: 0.059966713190078735
Epoch 3/3, Batch Loss: 0.15422847867012024
Epoch 3/3, Batch Loss: 0.3158389627933502
Epoch 3/3, Batch Loss: 0.17686836421489716
Epoch 3/3, Batch Loss: 0.36609187722206116
Epoch 3/3, Batch Loss: 0.16854634881019592
Epoch 3/3, Batch Loss: 0.18210133910179138
Epoch 3/3, Batch Loss: 0.31064993143081665
Epoch 3/3, Batch Loss: 0.2955982983112335
Epoch 3/3, Batch Loss: 0.07726722955703735
Epoch 3/3, Batc

Epoch 3/3, Batch Loss: 0.23881472647190094
Epoch 3/3, Batch Loss: 0.18443167209625244
Epoch 3/3, Batch Loss: 0.05739232152700424
Epoch 3/3, Batch Loss: 0.3171263337135315
Epoch 3/3, Batch Loss: 0.08077041804790497
Epoch 3/3, Batch Loss: 0.07095123082399368
Epoch 3/3, Batch Loss: 0.2609400749206543
Epoch 3/3, Batch Loss: 0.05908804014325142
Epoch 3/3, Batch Loss: 0.060590408742427826
Epoch 3/3, Batch Loss: 0.3922383487224579
Epoch 3/3, Batch Loss: 0.16386696696281433
Epoch 3/3, Batch Loss: 0.31625428795814514
Epoch 3/3, Batch Loss: 0.0846882238984108
Epoch 3/3, Batch Loss: 0.04676765203475952
Epoch 3/3, Batch Loss: 0.13056562840938568
Epoch 3/3, Batch Loss: 0.30822470784187317
Epoch 3/3, Batch Loss: 0.07365872710943222
Epoch 3/3, Batch Loss: 0.29837629199028015
Epoch 3/3, Batch Loss: 0.3868226706981659
Epoch 3/3, Batch Loss: 0.1790061742067337
Epoch 3/3, Batch Loss: 0.19321218132972717
Epoch 3/3, Batch Loss: 0.13190673291683197
Epoch 3/3, Batch Loss: 0.1673877239227295
Epoch 3/3, Batch 

Epoch 3/3, Batch Loss: 0.2104659080505371
Epoch 3/3, Batch Loss: 0.055079493671655655
Epoch 3/3, Batch Loss: 0.12329281866550446
Epoch 3/3, Batch Loss: 0.18355950713157654
Epoch 3/3, Batch Loss: 0.3433716297149658
Epoch 3/3, Batch Loss: 0.01774933561682701
Epoch 3/3, Batch Loss: 0.03669134899973869
Epoch 3/3, Batch Loss: 0.031288545578718185
Epoch 3/3, Batch Loss: 0.03790871798992157
Epoch 3/3, Batch Loss: 0.6886960864067078
Epoch 3/3, Batch Loss: 0.035198550671339035
Epoch 3/3, Batch Loss: 0.3393341898918152
Epoch 3/3, Batch Loss: 0.018181147053837776
Epoch 3/3, Batch Loss: 0.5047662854194641
Epoch 3/3, Batch Loss: 0.049138180911540985
Epoch 3/3, Batch Loss: 0.16967247426509857
Epoch 3/3, Batch Loss: 0.20841585099697113
Epoch 3/3, Batch Loss: 0.19308412075042725
Epoch 3/3, Batch Loss: 0.1712159961462021
Epoch 3/3, Batch Loss: 0.19298483431339264
Epoch 3/3, Batch Loss: 0.1608845442533493
Epoch 3/3, Batch Loss: 0.3768847584724426
Epoch 3/3, Batch Loss: 0.04425230622291565
Epoch 3/3, Bat

Epoch 3/3, Batch Loss: 0.03620085120201111
Epoch 3/3, Batch Loss: 0.184907004237175
Epoch 3/3, Batch Loss: 0.06322787702083588
Epoch 3/3, Batch Loss: 0.12155415862798691
Epoch 3/3, Batch Loss: 0.0898573249578476
Epoch 3/3, Batch Loss: 0.2040964961051941
Epoch 3/3, Batch Loss: 0.10421166568994522
Epoch 3/3, Batch Loss: 0.0393628366291523
Epoch 3/3, Batch Loss: 0.20702676475048065
Epoch 3/3, Batch Loss: 0.07313069701194763
Epoch 3/3, Batch Loss: 0.07531886547803879
Epoch 3/3, Batch Loss: 0.06038388982415199
Epoch 3/3, Batch Loss: 0.02515413798391819
Epoch 3/3, Batch Loss: 0.16399091482162476
Epoch 3/3, Batch Loss: 0.20929674804210663
Epoch 3/3, Batch Loss: 0.02776307612657547
Epoch 3/3, Batch Loss: 0.13663512468338013
Epoch 3/3, Batch Loss: 0.04741569980978966
Epoch 3/3, Batch Loss: 0.05503833666443825
Epoch 3/3, Batch Loss: 0.2625541388988495
Epoch 3/3, Batch Loss: 0.053443584591150284
Epoch 3/3, Batch Loss: 0.14316830039024353
Epoch 3/3, Batch Loss: 0.13624608516693115
Epoch 3/3, Batch

Epoch 3/3, Batch Loss: 0.3078761100769043
Epoch 3/3, Batch Loss: 0.04744509607553482
Epoch 3/3, Batch Loss: 0.04686727002263069
Epoch 3/3, Batch Loss: 0.38523614406585693
Epoch 3/3, Batch Loss: 0.06156197935342789
Epoch 3/3, Batch Loss: 0.06086263060569763
Epoch 3/3, Batch Loss: 0.2879902124404907
Epoch 3/3, Batch Loss: 0.10051199793815613
Epoch 3/3, Batch Loss: 0.023988449946045876
Epoch 3/3, Batch Loss: 0.037096090614795685
Epoch 3/3, Batch Loss: 0.035718075931072235
Epoch 3/3, Batch Loss: 0.12253665179014206
Epoch 3/3, Batch Loss: 0.05080685764551163
Epoch 3/3, Batch Loss: 0.1394435465335846
Epoch 3/3, Batch Loss: 0.2854834496974945
Epoch 3/3, Batch Loss: 0.1952107697725296
Epoch 3/3, Batch Loss: 0.2146567702293396
Epoch 3/3, Batch Loss: 0.30161529779434204
Epoch 3/3, Batch Loss: 0.19048085808753967
Epoch 3/3, Batch Loss: 0.03024361841380596
Epoch 3/3, Batch Loss: 0.19331541657447815
Epoch 3/3, Batch Loss: 0.0353815071284771
Epoch 3/3, Batch Loss: 0.2689090967178345
Epoch 3/3, Batch

Epoch 3/3, Batch Loss: 0.2826487421989441
Epoch 3/3, Batch Loss: 0.14851531386375427
Epoch 3/3, Batch Loss: 0.04219997301697731
Epoch 3/3, Batch Loss: 0.14531221985816956
Epoch 3/3, Batch Loss: 0.0756269320845604
Epoch 3/3, Batch Loss: 0.0798199400305748
Epoch 3/3, Batch Loss: 0.05707397311925888
Epoch 3/3, Batch Loss: 0.058272216469049454
Epoch 3/3, Batch Loss: 0.24354660511016846
Epoch 3/3, Batch Loss: 0.23071855306625366
Epoch 3/3, Batch Loss: 0.18644225597381592
Epoch 3/3, Batch Loss: 0.41196122765541077
Epoch 3/3, Batch Loss: 0.03843764215707779
Epoch 3/3, Batch Loss: 0.032007452100515366
Epoch 3/3, Batch Loss: 0.3039281368255615
Epoch 3/3, Batch Loss: 0.04303795099258423
Epoch 3/3, Batch Loss: 0.12848268449306488
Epoch 3/3, Batch Loss: 0.16609002649784088
Epoch 3/3, Batch Loss: 0.08333119004964828
Epoch 3/3, Batch Loss: 0.2762845754623413
Epoch 3/3, Batch Loss: 0.1866346150636673
Epoch 3/3, Batch Loss: 0.4180142879486084
Epoch 3/3, Batch Loss: 0.06823822855949402
Epoch 3/3, Batch

Epoch 3/3, Batch Loss: 0.10911989957094193
Epoch 3/3, Batch Loss: 0.09448906779289246
Epoch 3/3, Batch Loss: 0.14648129045963287
Epoch 3/3, Batch Loss: 0.09189967811107635
Epoch 3/3, Batch Loss: 0.3271927535533905
Epoch 3/3, Batch Loss: 0.010075953789055347
Epoch 3/3, Batch Loss: 0.045431219041347504
Epoch 3/3, Batch Loss: 0.3411840498447418
Epoch 3/3, Batch Loss: 0.07277920842170715
Epoch 3/3, Batch Loss: 0.027415182441473007
Epoch 3/3, Batch Loss: 0.03936365246772766
Epoch 3/3, Batch Loss: 0.3223458528518677
Epoch 3/3, Batch Loss: 0.07200649380683899
Epoch 3/3, Batch Loss: 0.5148178339004517
Epoch 3/3, Batch Loss: 0.07418740540742874
Epoch 3/3, Batch Loss: 0.18153047561645508
Epoch 3/3, Batch Loss: 0.23816873133182526
Epoch 3/3, Batch Loss: 0.038371920585632324
Epoch 3/3, Batch Loss: 0.12185511738061905
Epoch 3/3, Batch Loss: 0.17473119497299194
Epoch 3/3, Batch Loss: 0.3668544590473175
Epoch 3/3, Batch Loss: 0.062803253531456
Epoch 3/3, Batch Loss: 0.17659085988998413
Epoch 3/3, Bat

Epoch 3/3, Batch Loss: 0.04574360325932503
Epoch 3/3, Batch Loss: 0.07473892718553543
Epoch 3/3, Batch Loss: 0.06218763068318367
Epoch 3/3, Batch Loss: 0.09023596346378326
Epoch 3/3, Batch Loss: 0.29370418190956116
Epoch 3/3, Batch Loss: 0.0952245220541954
Epoch 3/3, Batch Loss: 0.1943105161190033
Epoch 3/3, Batch Loss: 0.1899283081293106
Epoch 3/3, Batch Loss: 0.23755250871181488
Epoch 3/3, Batch Loss: 0.3212462067604065
Epoch 3/3, Batch Loss: 0.06337855756282806
Epoch 3/3, Batch Loss: 0.05776616185903549
Epoch 3/3, Batch Loss: 0.0475875623524189
Epoch 3/3, Batch Loss: 0.31158727407455444
Epoch 3/3, Batch Loss: 0.18873313069343567
Epoch 3/3, Batch Loss: 0.29546821117401123
Epoch 3/3, Batch Loss: 0.3021143078804016
Epoch 3/3, Batch Loss: 0.03531273454427719
Epoch 3/3, Batch Loss: 0.040080729871988297
Epoch 3/3, Batch Loss: 0.07495639473199844
Epoch 3/3, Batch Loss: 0.17794303596019745
Epoch 3/3, Batch Loss: 0.1618076115846634
Epoch 3/3, Batch Loss: 0.2744868993759155
Epoch 3/3, Batch L

Epoch 3/3, Batch Loss: 0.12827832996845245
Epoch 3/3, Batch Loss: 0.177622988820076
Epoch 3/3, Batch Loss: 0.03310970962047577
Epoch 3/3, Batch Loss: 0.0784197449684143
Epoch 3/3, Batch Loss: 0.39039915800094604
Epoch 3/3, Batch Loss: 0.47896531224250793
Epoch 3/3, Batch Loss: 0.18567229807376862
Epoch 3/3, Batch Loss: 0.09958261996507645
Epoch 3/3, Batch Loss: 0.1160217672586441
Epoch 3/3, Batch Loss: 0.16968859732151031
Epoch 3/3, Batch Loss: 0.20042477548122406
Epoch 3/3, Batch Loss: 0.01988375186920166
Epoch 3/3, Batch Loss: 0.04056260734796524
Epoch 3/3, Batch Loss: 0.07019909471273422
Epoch 3/3, Batch Loss: 0.4419177174568176
Epoch 3/3, Batch Loss: 0.0759660005569458
Epoch 3/3, Batch Loss: 0.09855903685092926
Epoch 3/3, Batch Loss: 0.07342477887868881
Epoch 3/3, Batch Loss: 0.07941167801618576
Epoch 3/3, Batch Loss: 0.04079863429069519
Epoch 3/3, Batch Loss: 0.1997780054807663
Epoch 3/3, Batch Loss: 0.07172474265098572
Epoch 3/3, Batch Loss: 0.0395829938352108
Epoch 3/3, Batch Lo

Epoch 3/3, Batch Loss: 0.16430127620697021
Epoch 3/3, Batch Loss: 0.1325996369123459
Epoch 3/3, Batch Loss: 0.19143006205558777
Epoch 3/3, Batch Loss: 0.20085807144641876
Epoch 3/3, Batch Loss: 0.19907377660274506
Epoch 3/3, Batch Loss: 0.27888059616088867
Epoch 3/3, Batch Loss: 0.2188510298728943
Epoch 3/3, Batch Loss: 0.03513288497924805
Epoch 3/3, Batch Loss: 0.11858030408620834
Epoch 3/3, Batch Loss: 0.3664565980434418
Epoch 3/3, Batch Loss: 0.1541798710823059
Epoch 3/3, Batch Loss: 0.09763067960739136
Epoch 3/3, Batch Loss: 0.09862413257360458
Epoch 3/3, Batch Loss: 0.20892837643623352
Epoch 3/3, Batch Loss: 0.12782315909862518
Epoch 3/3, Batch Loss: 0.09614238142967224
Epoch 3/3, Batch Loss: 0.17191259562969208
Epoch 3/3, Batch Loss: 0.0503455214202404
Epoch 3/3, Batch Loss: 0.08436547964811325
Epoch 3/3, Batch Loss: 0.20860685408115387
Epoch 3/3, Batch Loss: 0.37402036786079407
Epoch 3/3, Batch Loss: 0.038687434047460556
Epoch 3/3, Batch Loss: 0.06344728171825409
Epoch 3/3, Batc

Epoch 3/3, Batch Loss: 0.34711164236068726
Epoch 3/3, Batch Loss: 0.15629467368125916
Epoch 3/3, Batch Loss: 0.06774136424064636
Epoch 3/3, Batch Loss: 0.17540863156318665
Epoch 3/3, Batch Loss: 0.2309858202934265
Epoch 3/3, Batch Loss: 0.06871016323566437
Epoch 3/3, Batch Loss: 0.17036493122577667
Epoch 3/3, Batch Loss: 0.2876114845275879
Epoch 3/3, Batch Loss: 0.17037798464298248
Epoch 3/3, Batch Loss: 0.36086976528167725
Epoch 3/3, Batch Loss: 0.0929524376988411
Epoch 3/3, Batch Loss: 0.09265431016683578
Epoch 3/3, Batch Loss: 0.14028358459472656
Epoch 3/3, Batch Loss: 0.1486481875181198
Epoch 3/3, Batch Loss: 0.15818604826927185
Epoch 3/3, Batch Loss: 0.14402633905410767
Epoch 3/3, Batch Loss: 0.3162291646003723
Epoch 3/3, Batch Loss: 0.4148951768875122
Epoch 3/3, Batch Loss: 0.27749499678611755
Epoch 3/3, Batch Loss: 0.17490699887275696
Epoch 3/3, Batch Loss: 0.14623236656188965
Epoch 3/3, Batch Loss: 0.07228325307369232
Epoch 3/3, Batch Loss: 0.13195982575416565
Epoch 3/3, Batch 

Epoch 3/3, Batch Loss: 0.1679651141166687
Epoch 3/3, Batch Loss: 0.07242653518915176
Epoch 3/3, Batch Loss: 0.44177213311195374
Epoch 3/3, Batch Loss: 0.06924211978912354
Epoch 3/3, Batch Loss: 0.054864972829818726
Epoch 3/3, Batch Loss: 0.341724693775177
Epoch 3/3, Batch Loss: 0.2540881931781769
Epoch 3/3, Batch Loss: 0.2431134581565857
Epoch 3/3, Batch Loss: 0.14289723336696625
Epoch 3/3, Batch Loss: 0.20378534495830536
Epoch 3/3, Batch Loss: 0.6234591007232666
Epoch 3/3, Batch Loss: 0.1519833654165268
Epoch 3/3, Batch Loss: 0.09896190464496613
Epoch 3/3, Batch Loss: 0.0745716467499733
Epoch 3/3, Batch Loss: 0.12979517877101898
Epoch 3/3, Batch Loss: 0.1994321197271347
Epoch 3/3, Batch Loss: 0.04654661566019058
Epoch 3/3, Batch Loss: 0.14255373179912567
Epoch 3/3, Batch Loss: 0.08031487464904785
Epoch 3/3, Batch Loss: 0.15826934576034546
Epoch 3/3, Batch Loss: 0.224015012383461
Epoch 3/3, Batch Loss: 0.18246324360370636
Epoch 3/3, Batch Loss: 0.07443448901176453
Epoch 3/3, Batch Loss

Epoch 3/3, Batch Loss: 0.030734319239854813
Epoch 3/3, Batch Loss: 0.1911545991897583
Epoch 3/3, Batch Loss: 0.07540176808834076
Epoch 3/3, Batch Loss: 0.1665881872177124
Epoch 3/3, Batch Loss: 0.4484604001045227
Epoch 3/3, Batch Loss: 0.10307525098323822
Epoch 3/3, Batch Loss: 0.20333801209926605
Epoch 3/3, Batch Loss: 0.09012465178966522
Epoch 3/3, Batch Loss: 0.07120656222105026
Epoch 3/3, Batch Loss: 0.19247651100158691
Epoch 3/3, Batch Loss: 0.09406279027462006
Epoch 3/3, Batch Loss: 0.2052002102136612
Epoch 3/3, Batch Loss: 0.25156041979789734
Epoch 3/3, Batch Loss: 0.1411043405532837
Epoch 3/3, Batch Loss: 0.1604359745979309
Epoch 3/3, Batch Loss: 0.2366349697113037
Epoch 3/3, Batch Loss: 0.16271160542964935
Epoch 3/3, Batch Loss: 0.06695741415023804
Epoch 3/3, Batch Loss: 0.5400360822677612
Epoch 3/3, Batch Loss: 0.22339898347854614
Epoch 3/3, Batch Loss: 0.2953013777732849
Epoch 3/3, Batch Loss: 0.24248866736888885
Epoch 3/3, Batch Loss: 0.20203529298305511
Epoch 3/3, Batch Lo

In [8]:
# Fine-tune용
import os
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
import torch
from sklearn.model_selection import train_test_split

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-5]  # n번째 레이어의 hidden states를 반환합니다.
        loss = outputs.loss
        return logits, loss, hidden_states

# 데이터 로드 및 전처리
data_A = pd.read_csv("output3.csv")  # data set A 파일명에 맞게 수정
data_B = pd.read_csv("infected.csv")  # data set B 파일명에 맞게 수정
# 모델 불러오는 경로
model_path = "Pre-trained.pt"
# 모델 저장경로
model_path2 = "Fine-tuned.pt"

# X_train, Y_train 생성
X_train = []
Y_train = []

for index, row in data_A.iterrows():  # 중복 제거를 하지 않고 원본 데이터 사용
    patient_id = row["ID"]
    patient_info = [str(row[column]) for column in data_A.columns if column != "ID" and column != "DESCRIPTION"]
    symptoms = ", ".join(data_A[data_A["ID"] == patient_id]["DESCRIPTION"].tolist())
    combined_info = ", ".join(patient_info) + ", " + symptoms
    X_train.append(combined_info)
    if patient_id in data_B.values:
        Y_train.append(1)
    else:
        Y_train.append(0)

print("X_train\n", X_train[:10])
print("Y_train\n", Y_train[:10])
        
# BERT 토크나이저 및 모델 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 모델이 이미 저장되어 있는지 확인하고, 저장된 모델이 있으면 불러오고 없으면 새로운 모델 생성
if os.path.exists(model_path):
    # 저장된 모델이 있을 경우 불러오기
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    print("Pre-train model loaded.")
else:
    # 저장된 모델이 없을 경우 새로운 모델 생성
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    print("New model generated.")

# 입력 데이터를 BERT의 입력 형식으로 변환
max_len = 128  # 입력 시퀀스의 최대 길이

input_ids = []
attention_masks = []

for info in X_train:
    encoded_dict = tokenizer.encode_plus(
                        info,                         # 환자 정보 및 증상
                        add_special_tokens = True,    # [CLS], [SEP] 토큰 추가
                        max_length = max_len,         # 최대 길이 지정
                        pad_to_max_length = True,     # 패딩을 추가하여 최대 길이로 맞춤
                        return_attention_mask = True, # 어텐션 마스크 생성
                        return_tensors = 'pt',        # PyTorch 텐서로 반환
                   )
    
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(Y_train)

# 데이터셋 및 데이터로더 생성
dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = 0.8
train_dataset, val_dataset = train_test_split(dataset, test_size=1-train_size, random_state=42)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

# 모델을 GPU로 이동
model.to(device)

# 옵티마이저 및 학습률 설정
# 기본 학습률 : 2e-6
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)

# 에폭 설정
epochs = 3

# 학습 루프
hidden_states_list = []  # 모든 에폭에 대한 hidden state를 저장할 리스트
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[2]}
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = outputs[1]  # loss가 outputs의 두 번째 값입니다.
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch + 1}/{epochs}, Batch Loss: {loss.item()}')
        # hidden state를 저장합니다.
        #hidden_states = outputs[2]
        #hidden_states_list.append(hidden_states)

    avg_train_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch + 1}/{epochs}, Average Training Loss: {avg_train_loss}')

# 모든 에폭에 대한 hidden state를 합쳐서 CSV 파일로 저장합니다.
#hidden_states_concat = torch.cat(hidden_states_list, dim=0)
#hidden_states_concat = hidden_states_concat[:, 0, :].cpu().detach().numpy()
#hidden_states_df = pd.DataFrame(hidden_states_concat)
#hidden_states_df.to_csv("hidden_states_all_epochs.csv", index=False)

# 모델 저장
torch.save(model.state_dict(), model_path2)

# 모델 평가
model.eval()
val_accuracy = 0
for batch in val_dataloader:
    batch = tuple(t.to(device) for t in batch)
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'labels': batch[2]}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs[0]  # logits가 outputs의 첫 번째 값입니다.
    logits = logits.detach().cpu().numpy()
    label_ids = inputs['labels'].cpu().numpy()
    val_accuracy += (logits.argmax(axis=1) == label_ids).mean().item()

print(f'Validation Accuracy: {val_accuracy / len(val_dataloader)}')


X_train
 ['5/11/1990, nan, 999-50-6899, S99965785, X23414632X, Mrs., Golden321, Berge125, nan, Becker968, M, white, nonhispanic, F, Saugus  Massachusetts  US, 900 Kihn Loaf Apt 75, Foxborough, Massachusetts, Norfolk County, nan, 42.07605525, -71.28353458, 645780.38, 7741.1, 5/10/2003, nan, 561d9445-df64-4fb9-a38a-37ccb670317e, 74400008.0, Appendicitis, Rupture of appendix', '10/5/1957, nan, 999-54-2768, S99932091, X75665632X, Mrs., Nevada145, Brakus656, nan, Schneider199, M, black, nonhispanic, F, Conway  Massachusetts  US, 844 Harris Ferry, Boston, Massachusetts, Suffolk County, 2130.0, 42.32797446, -71.15729167, 1352661.73, 14479.55, 3/8/2020, 3/16/2020, 92a64d19-1760-4344-b109-447bf15d70c1, 65710008.0, Acute respiratory failure (disorder), Hypoxemia (disorder)', '11/26/1969, nan, 999-84-2178, S99933880, X42152198X, Mrs., Tequila897, Ward668, nan, Larkin917, M, black, nonhispanic, F, Woburn  Massachusetts  US, 1083 Greenholt Lane, Athol, Massachusetts, Worcester County, 1331.0, 42.54

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Pre-train model loaded.
True
Epoch 1/3, Batch Loss: 0.35644716024398804
Epoch 1/3, Batch Loss: 0.06446950137615204
Epoch 1/3, Batch Loss: 0.49225953221321106
Epoch 1/3, Batch Loss: 0.04072669893503189
Epoch 1/3, Batch Loss: 0.024975890293717384
Epoch 1/3, Batch Loss: 0.46829500794410706
Epoch 1/3, Batch Loss: 0.267522931098938
Epoch 1/3, Batch Loss: 0.05015619471669197
Epoch 1/3, Batch Loss: 0.31322547793388367
Epoch 1/3, Batch Loss: 0.11723592132329941
Epoch 1/3, Batch Loss: 0.027123499661684036
Epoch 1/3, Batch Loss: 0.31227344274520874
Epoch 1/3, Batch Loss: 0.1712123453617096
Epoch 1/3, Batch Loss: 0.1577252745628357
Epoch 1/3, Batch Loss: 0.09736476093530655
Epoch 1/3, Batch Loss: 0.1891738623380661
Epoch 1/3, Batch Loss: 0.28676292300224304
Epoch 1/3, Batch Loss: 0.08035378903150558
Epoch 1/3, Batch Loss: 0.20341376960277557
Epoch 1/3, Batch Loss: 0.14915987849235535
Epoch 1/3, Batch Loss: 0.2702379524707794
Epoch 1/3, Batch Loss: 0.07132475078105927
Epoch 1/3, Batch Loss: 0.2160

Epoch 1/3, Batch Loss: 0.1531742960214615
Epoch 1/3, Batch Loss: 0.2900690734386444
Epoch 1/3, Batch Loss: 0.07160258293151855
Epoch 1/3, Batch Loss: 0.032003406435251236
Epoch 1/3, Batch Loss: 0.2995308041572571
Epoch 1/3, Batch Loss: 0.19176769256591797
Epoch 1/3, Batch Loss: 0.39847245812416077
Epoch 1/3, Batch Loss: 0.0492219403386116
Epoch 1/3, Batch Loss: 0.17398422956466675
Epoch 1/3, Batch Loss: 0.13572709262371063
Epoch 1/3, Batch Loss: 0.12713825702667236
Epoch 1/3, Batch Loss: 0.22014327347278595
Epoch 1/3, Batch Loss: 0.5583798289299011
Epoch 1/3, Batch Loss: 0.0546497106552124
Epoch 1/3, Batch Loss: 0.20439910888671875
Epoch 1/3, Batch Loss: 0.24356591701507568
Epoch 1/3, Batch Loss: 0.0643753632903099
Epoch 1/3, Batch Loss: 0.3178096115589142
Epoch 1/3, Batch Loss: 0.14577288925647736
Epoch 1/3, Batch Loss: 0.16596296429634094
Epoch 1/3, Batch Loss: 0.20488418638706207
Epoch 1/3, Batch Loss: 0.09423104673624039
Epoch 1/3, Batch Loss: 0.08652830868959427
Epoch 1/3, Batch L

Epoch 1/3, Batch Loss: 0.34610921144485474
Epoch 1/3, Batch Loss: 0.29556795954704285
Epoch 1/3, Batch Loss: 0.4693663716316223
Epoch 1/3, Batch Loss: 0.03087153658270836
Epoch 1/3, Batch Loss: 0.29707369208335876
Epoch 1/3, Batch Loss: 0.09336701780557632
Epoch 1/3, Batch Loss: 0.2977447807788849
Epoch 1/3, Batch Loss: 0.2844799757003784
Epoch 1/3, Batch Loss: 0.16904765367507935
Epoch 1/3, Batch Loss: 0.1809263825416565
Epoch 1/3, Batch Loss: 0.09122513979673386
Epoch 1/3, Batch Loss: 0.14891384541988373
Epoch 1/3, Batch Loss: 0.17744840681552887
Epoch 1/3, Batch Loss: 0.03876589611172676
Epoch 1/3, Batch Loss: 0.09112913906574249
Epoch 1/3, Batch Loss: 0.14581064879894257
Epoch 1/3, Batch Loss: 0.047620125114917755
Epoch 1/3, Batch Loss: 0.16075921058654785
Epoch 1/3, Batch Loss: 0.03539394587278366
Epoch 1/3, Batch Loss: 0.03614191710948944
Epoch 1/3, Batch Loss: 0.06275549530982971
Epoch 1/3, Batch Loss: 0.3630457818508148
Epoch 1/3, Batch Loss: 0.556330144405365
Epoch 1/3, Batch 

Epoch 1/3, Batch Loss: 0.033481281250715256
Epoch 1/3, Batch Loss: 0.6270480155944824
Epoch 1/3, Batch Loss: 0.20397749543190002
Epoch 1/3, Batch Loss: 0.3104979693889618
Epoch 1/3, Batch Loss: 0.21079793572425842
Epoch 1/3, Batch Loss: 0.21174930036067963
Epoch 1/3, Batch Loss: 0.29921892285346985
Epoch 1/3, Batch Loss: 0.06456327438354492
Epoch 1/3, Batch Loss: 0.18274541199207306
Epoch 1/3, Batch Loss: 0.18125633895397186
Epoch 1/3, Batch Loss: 0.2743353247642517
Epoch 1/3, Batch Loss: 0.3485981225967407
Epoch 1/3, Batch Loss: 0.3391471803188324
Epoch 1/3, Batch Loss: 0.11286390572786331
Epoch 1/3, Batch Loss: 0.23879237473011017
Epoch 1/3, Batch Loss: 0.08285154402256012
Epoch 1/3, Batch Loss: 0.06419475376605988
Epoch 1/3, Batch Loss: 0.24947242438793182
Epoch 1/3, Batch Loss: 0.17841635644435883
Epoch 1/3, Batch Loss: 0.03197910636663437
Epoch 1/3, Batch Loss: 0.1418219655752182
Epoch 1/3, Batch Loss: 0.18125978112220764
Epoch 1/3, Batch Loss: 0.3686193823814392
Epoch 1/3, Batch 

Epoch 1/3, Batch Loss: 0.019078584387898445
Epoch 1/3, Batch Loss: 0.054134078323841095
Epoch 1/3, Batch Loss: 0.32029202580451965
Epoch 1/3, Batch Loss: 0.7002848982810974
Epoch 1/3, Batch Loss: 0.07041557878255844
Epoch 1/3, Batch Loss: 0.24875734746456146
Epoch 1/3, Batch Loss: 0.07645060867071152
Epoch 1/3, Batch Loss: 0.22575372457504272
Epoch 1/3, Batch Loss: 0.08133599162101746
Epoch 1/3, Batch Loss: 0.08561912924051285
Epoch 1/3, Batch Loss: 0.12999525666236877
Epoch 1/3, Batch Loss: 0.12184275686740875
Epoch 1/3, Batch Loss: 0.2451765388250351
Epoch 1/3, Batch Loss: 0.30078354477882385
Epoch 1/3, Batch Loss: 0.18925967812538147
Epoch 1/3, Batch Loss: 0.2590365409851074
Epoch 1/3, Batch Loss: 0.14376424252986908
Epoch 1/3, Batch Loss: 0.4029372036457062
Epoch 1/3, Batch Loss: 0.05297787860035896
Epoch 1/3, Batch Loss: 0.14574433863162994
Epoch 1/3, Batch Loss: 0.057210780680179596
Epoch 1/3, Batch Loss: 0.06979028135538101
Epoch 1/3, Batch Loss: 0.4207924008369446
Epoch 1/3, Ba

Epoch 1/3, Batch Loss: 0.03524075075984001
Epoch 1/3, Batch Loss: 0.10854899883270264
Epoch 1/3, Batch Loss: 0.18255165219306946
Epoch 1/3, Batch Loss: 0.018695583567023277
Epoch 1/3, Batch Loss: 0.20796695351600647
Epoch 1/3, Batch Loss: 0.05453336611390114
Epoch 1/3, Batch Loss: 0.1519196480512619
Epoch 1/3, Batch Loss: 0.18284429609775543
Epoch 1/3, Batch Loss: 0.1387006938457489
Epoch 1/3, Batch Loss: 0.04669569060206413
Epoch 1/3, Batch Loss: 0.23929643630981445
Epoch 1/3, Average Training Loss: 0.17345173280478968
Epoch 2/3, Batch Loss: 0.14802595973014832
Epoch 2/3, Batch Loss: 0.305084764957428
Epoch 2/3, Batch Loss: 0.15930484235286713
Epoch 2/3, Batch Loss: 0.06544852256774902
Epoch 2/3, Batch Loss: 0.1643039733171463
Epoch 2/3, Batch Loss: 0.2516372501850128
Epoch 2/3, Batch Loss: 0.15943270921707153
Epoch 2/3, Batch Loss: 0.01433564443141222
Epoch 2/3, Batch Loss: 0.08435852080583572
Epoch 2/3, Batch Loss: 0.13077236711978912
Epoch 2/3, Batch Loss: 0.09956496208906174
Epoch

Epoch 2/3, Batch Loss: 0.18349674344062805
Epoch 2/3, Batch Loss: 0.034297406673431396
Epoch 2/3, Batch Loss: 0.29875412583351135
Epoch 2/3, Batch Loss: 0.1718786209821701
Epoch 2/3, Batch Loss: 0.24328982830047607
Epoch 2/3, Batch Loss: 0.13630129396915436
Epoch 2/3, Batch Loss: 0.24961325526237488
Epoch 2/3, Batch Loss: 0.13585586845874786
Epoch 2/3, Batch Loss: 0.02710254117846489
Epoch 2/3, Batch Loss: 0.14731168746948242
Epoch 2/3, Batch Loss: 0.25004950165748596
Epoch 2/3, Batch Loss: 0.25171396136283875
Epoch 2/3, Batch Loss: 0.033792901784181595
Epoch 2/3, Batch Loss: 0.2670421898365021
Epoch 2/3, Batch Loss: 0.11865291744470596
Epoch 2/3, Batch Loss: 0.038662489503622055
Epoch 2/3, Batch Loss: 0.18801763653755188
Epoch 2/3, Batch Loss: 0.2990275025367737
Epoch 2/3, Batch Loss: 0.17290467023849487
Epoch 2/3, Batch Loss: 0.06509263068437576
Epoch 2/3, Batch Loss: 0.0681196004152298
Epoch 2/3, Batch Loss: 0.07436364889144897
Epoch 2/3, Batch Loss: 0.24080048501491547
Epoch 2/3, B

Epoch 2/3, Batch Loss: 0.17939536273479462
Epoch 2/3, Batch Loss: 0.17653527855873108
Epoch 2/3, Batch Loss: 0.25791677832603455
Epoch 2/3, Batch Loss: 0.5909048318862915
Epoch 2/3, Batch Loss: 0.2734295725822449
Epoch 2/3, Batch Loss: 0.3214224874973297
Epoch 2/3, Batch Loss: 0.2543002963066101
Epoch 2/3, Batch Loss: 0.12669417262077332
Epoch 2/3, Batch Loss: 0.29204753041267395
Epoch 2/3, Batch Loss: 0.05416342616081238
Epoch 2/3, Batch Loss: 0.06918105483055115
Epoch 2/3, Batch Loss: 0.1945546269416809
Epoch 2/3, Batch Loss: 0.32174739241600037
Epoch 2/3, Batch Loss: 0.16468843817710876
Epoch 2/3, Batch Loss: 0.02701990120112896
Epoch 2/3, Batch Loss: 0.28120970726013184
Epoch 2/3, Batch Loss: 0.08281724154949188
Epoch 2/3, Batch Loss: 0.3198547065258026
Epoch 2/3, Batch Loss: 0.1485927850008011
Epoch 2/3, Batch Loss: 0.0705997496843338
Epoch 2/3, Batch Loss: 0.15440939366817474
Epoch 2/3, Batch Loss: 0.17463786900043488
Epoch 2/3, Batch Loss: 0.14741823077201843
Epoch 2/3, Batch Lo

Epoch 2/3, Batch Loss: 0.046449657529592514
Epoch 2/3, Batch Loss: 0.16750559210777283
Epoch 2/3, Batch Loss: 0.044119175523519516
Epoch 2/3, Batch Loss: 0.32692763209342957
Epoch 2/3, Batch Loss: 0.19305162131786346
Epoch 2/3, Batch Loss: 0.05584177002310753
Epoch 2/3, Batch Loss: 0.10875163227319717
Epoch 2/3, Batch Loss: 0.15079320967197418
Epoch 2/3, Batch Loss: 0.1637493222951889
Epoch 2/3, Batch Loss: 0.0709027424454689
Epoch 2/3, Batch Loss: 0.05581067129969597
Epoch 2/3, Batch Loss: 0.030623286962509155
Epoch 2/3, Batch Loss: 0.29525747895240784
Epoch 2/3, Batch Loss: 0.31963828206062317
Epoch 2/3, Batch Loss: 0.07960972934961319
Epoch 2/3, Batch Loss: 0.05906660854816437
Epoch 2/3, Batch Loss: 0.16815973818302155
Epoch 2/3, Batch Loss: 0.02282172068953514
Epoch 2/3, Batch Loss: 0.12795479595661163
Epoch 2/3, Batch Loss: 0.0312616266310215
Epoch 2/3, Batch Loss: 0.0215320847928524
Epoch 2/3, Batch Loss: 0.10381065309047699
Epoch 2/3, Batch Loss: 0.04812092334032059
Epoch 2/3, B

Epoch 2/3, Batch Loss: 0.29326313734054565
Epoch 2/3, Batch Loss: 0.17147871851921082
Epoch 2/3, Batch Loss: 0.0875081717967987
Epoch 2/3, Batch Loss: 0.3356613516807556
Epoch 2/3, Batch Loss: 0.07716593891382217
Epoch 2/3, Batch Loss: 0.48294979333877563
Epoch 2/3, Batch Loss: 0.03858717530965805
Epoch 2/3, Batch Loss: 0.11943954229354858
Epoch 2/3, Batch Loss: 0.16430576145648956
Epoch 2/3, Batch Loss: 0.06717605888843536
Epoch 2/3, Batch Loss: 0.22479599714279175
Epoch 2/3, Batch Loss: 0.25634270906448364
Epoch 2/3, Batch Loss: 0.27033495903015137
Epoch 2/3, Batch Loss: 0.3680388033390045
Epoch 2/3, Batch Loss: 0.03850107640028
Epoch 2/3, Batch Loss: 0.0720294788479805
Epoch 2/3, Batch Loss: 0.15098603069782257
Epoch 2/3, Batch Loss: 0.6033148169517517
Epoch 2/3, Batch Loss: 0.05773773044347763
Epoch 2/3, Batch Loss: 0.09043504297733307
Epoch 2/3, Batch Loss: 0.2577856779098511
Epoch 2/3, Batch Loss: 0.13413789868354797
Epoch 2/3, Batch Loss: 0.1076153814792633
Epoch 2/3, Batch Loss

Epoch 2/3, Batch Loss: 0.21957653760910034
Epoch 2/3, Batch Loss: 0.1247650533914566
Epoch 2/3, Batch Loss: 0.05777272582054138
Epoch 2/3, Batch Loss: 0.15102870762348175
Epoch 2/3, Batch Loss: 0.5010320544242859
Epoch 2/3, Batch Loss: 0.1398581713438034
Epoch 2/3, Batch Loss: 0.6087892055511475
Epoch 2/3, Batch Loss: 0.11649198830127716
Epoch 2/3, Batch Loss: 0.21961677074432373
Epoch 2/3, Batch Loss: 0.19740663468837738
Epoch 2/3, Batch Loss: 0.2633785307407379
Epoch 2/3, Batch Loss: 0.051677048206329346
Epoch 2/3, Batch Loss: 0.2027502954006195
Epoch 2/3, Batch Loss: 0.1465640515089035
Epoch 2/3, Batch Loss: 0.20581495761871338
Epoch 2/3, Batch Loss: 0.4071029722690582
Epoch 2/3, Batch Loss: 0.23273661732673645
Epoch 2/3, Batch Loss: 0.17979013919830322
Epoch 2/3, Batch Loss: 0.07140768319368362
Epoch 2/3, Batch Loss: 0.10083869099617004
Epoch 2/3, Batch Loss: 0.2346857190132141
Epoch 2/3, Average Training Loss: 0.17056572129211392
Epoch 3/3, Batch Loss: 0.1342838555574417
Epoch 3/3

Epoch 3/3, Batch Loss: 0.10885586589574814
Epoch 3/3, Batch Loss: 0.060866374522447586
Epoch 3/3, Batch Loss: 0.15056972205638885
Epoch 3/3, Batch Loss: 0.1282985806465149
Epoch 3/3, Batch Loss: 0.039345644414424896
Epoch 3/3, Batch Loss: 0.05962114408612251
Epoch 3/3, Batch Loss: 0.05160920321941376
Epoch 3/3, Batch Loss: 0.05373261496424675
Epoch 3/3, Batch Loss: 0.31569141149520874
Epoch 3/3, Batch Loss: 0.10887966305017471
Epoch 3/3, Batch Loss: 0.14469477534294128
Epoch 3/3, Batch Loss: 0.17081418633460999
Epoch 3/3, Batch Loss: 0.2547680735588074
Epoch 3/3, Batch Loss: 0.10032179206609726
Epoch 3/3, Batch Loss: 0.16850979626178741
Epoch 3/3, Batch Loss: 0.013749380595982075
Epoch 3/3, Batch Loss: 0.5373250246047974
Epoch 3/3, Batch Loss: 0.03562391549348831
Epoch 3/3, Batch Loss: 0.17635464668273926
Epoch 3/3, Batch Loss: 0.20212452113628387
Epoch 3/3, Batch Loss: 0.09952102601528168
Epoch 3/3, Batch Loss: 0.2747792899608612
Epoch 3/3, Batch Loss: 0.21777166426181793
Epoch 3/3, B

Epoch 3/3, Batch Loss: 0.19876965880393982
Epoch 3/3, Batch Loss: 0.14148783683776855
Epoch 3/3, Batch Loss: 0.19141295552253723
Epoch 3/3, Batch Loss: 0.027488749474287033
Epoch 3/3, Batch Loss: 0.14351306855678558
Epoch 3/3, Batch Loss: 0.059201184660196304
Epoch 3/3, Batch Loss: 0.25417789816856384
Epoch 3/3, Batch Loss: 0.2051084041595459
Epoch 3/3, Batch Loss: 0.43192559480667114
Epoch 3/3, Batch Loss: 0.15885265171527863
Epoch 3/3, Batch Loss: 0.40538695454597473
Epoch 3/3, Batch Loss: 0.15903152525424957
Epoch 3/3, Batch Loss: 0.24459466338157654
Epoch 3/3, Batch Loss: 0.2789527177810669
Epoch 3/3, Batch Loss: 0.1277133971452713
Epoch 3/3, Batch Loss: 0.12572315335273743
Epoch 3/3, Batch Loss: 0.04382966831326485
Epoch 3/3, Batch Loss: 0.04408558830618858
Epoch 3/3, Batch Loss: 0.09213494509458542
Epoch 3/3, Batch Loss: 0.3205047845840454
Epoch 3/3, Batch Loss: 0.2658361494541168
Epoch 3/3, Batch Loss: 0.2882544994354248
Epoch 3/3, Batch Loss: 0.09137771278619766
Epoch 3/3, Batc

Epoch 3/3, Batch Loss: 0.03962955251336098
Epoch 3/3, Batch Loss: 0.17252054810523987
Epoch 3/3, Batch Loss: 0.06631484627723694
Epoch 3/3, Batch Loss: 0.09252310544252396
Epoch 3/3, Batch Loss: 0.11239134520292282
Epoch 3/3, Batch Loss: 0.03258439153432846
Epoch 3/3, Batch Loss: 0.6044567227363586
Epoch 3/3, Batch Loss: 0.054372262209653854
Epoch 3/3, Batch Loss: 0.4064238369464874
Epoch 3/3, Batch Loss: 0.15903066098690033
Epoch 3/3, Batch Loss: 0.029791681095957756
Epoch 3/3, Batch Loss: 0.3152630031108856
Epoch 3/3, Batch Loss: 0.19722042977809906
Epoch 3/3, Batch Loss: 0.13422074913978577
Epoch 3/3, Batch Loss: 0.21994782984256744
Epoch 3/3, Batch Loss: 0.16267144680023193
Epoch 3/3, Batch Loss: 0.35829776525497437
Epoch 3/3, Batch Loss: 0.18595057725906372
Epoch 3/3, Batch Loss: 0.4097293019294739
Epoch 3/3, Batch Loss: 0.14947469532489777
Epoch 3/3, Batch Loss: 0.0685151219367981
Epoch 3/3, Batch Loss: 0.0975702628493309
Epoch 3/3, Batch Loss: 0.2135740965604782
Epoch 3/3, Batch

Epoch 3/3, Batch Loss: 0.06666618585586548
Epoch 3/3, Batch Loss: 0.16279500722885132
Epoch 3/3, Batch Loss: 0.22288434207439423
Epoch 3/3, Batch Loss: 0.10658922046422958
Epoch 3/3, Batch Loss: 0.1609991192817688
Epoch 3/3, Batch Loss: 0.392749160528183
Epoch 3/3, Batch Loss: 0.034985680133104324
Epoch 3/3, Batch Loss: 0.6328175663948059
Epoch 3/3, Batch Loss: 0.25472521781921387
Epoch 3/3, Batch Loss: 0.23559163510799408
Epoch 3/3, Batch Loss: 0.0516403503715992
Epoch 3/3, Batch Loss: 0.09910814464092255
Epoch 3/3, Batch Loss: 0.15503990650177002
Epoch 3/3, Batch Loss: 0.3975571393966675
Epoch 3/3, Batch Loss: 0.08343853056430817
Epoch 3/3, Batch Loss: 0.05380073934793472
Epoch 3/3, Batch Loss: 0.0813971757888794
Epoch 3/3, Batch Loss: 0.1504029631614685
Epoch 3/3, Batch Loss: 0.0378723219037056
Epoch 3/3, Batch Loss: 0.08079288899898529
Epoch 3/3, Batch Loss: 0.31293126940727234
Epoch 3/3, Batch Loss: 0.08859486132860184
Epoch 3/3, Batch Loss: 0.13454806804656982
Epoch 3/3, Batch Lo

Epoch 3/3, Batch Loss: 0.34209775924682617
Epoch 3/3, Batch Loss: 0.11648349463939667
Epoch 3/3, Batch Loss: 0.3485059440135956
Epoch 3/3, Batch Loss: 0.05554275959730148
Epoch 3/3, Batch Loss: 0.11406367272138596
Epoch 3/3, Batch Loss: 0.17429490387439728
Epoch 3/3, Batch Loss: 0.11334726214408875
Epoch 3/3, Batch Loss: 0.06822095066308975
Epoch 3/3, Batch Loss: 0.16231417655944824
Epoch 3/3, Batch Loss: 0.3176786005496979
Epoch 3/3, Batch Loss: 0.16107504069805145
Epoch 3/3, Batch Loss: 0.12196222692728043
Epoch 3/3, Batch Loss: 0.2595650255680084
Epoch 3/3, Batch Loss: 0.3519720137119293
Epoch 3/3, Batch Loss: 0.4072616994380951
Epoch 3/3, Batch Loss: 0.1399233639240265
Epoch 3/3, Batch Loss: 0.07074470818042755
Epoch 3/3, Batch Loss: 0.08141618221998215
Epoch 3/3, Batch Loss: 0.25972944498062134
Epoch 3/3, Batch Loss: 0.07537350058555603
Epoch 3/3, Batch Loss: 0.06127890571951866
Epoch 3/3, Batch Loss: 0.20548784732818604
Epoch 3/3, Batch Loss: 0.05908646807074547
Epoch 3/3, Batch 

In [15]:
# 데이터 랜덤분할(100/500)
import pandas as pd

def sample_csv_and_additional(input_file, output_file_500, output_file_100, n_500):
    # CSV 파일을 읽어옵니다.
    data = pd.read_csv(input_file)
    
    # 데이터를 랜덤하게 샘플링합니다.
    sampled_data_500 = data.sample(n=n_500, random_state=42)
    
    # 샘플링된 500개의 데이터를 CSV 파일로 내보냅니다.
    sampled_data_500.to_csv(output_file_500, index=False)
    
    # sampled_data_500에서 첫 100개의 데이터를 선택합니다.
    sampled_data_100 = sampled_data_500.head(100)
    
    # 선택된 첫 100개의 데이터를 CSV 파일로 내보냅니다.
    sampled_data_100.to_csv(output_file_100, index=False)

# 입력 CSV 파일 경로
input_file = "cleaned_covid_merged.csv"

# 출력 CSV 파일 경로
output_file_500 = "random_500.csv"
output_file_100 = "random_100.csv"

# 랜덤하게 추출할 데이터 개수
n_500 = 500

# 함수 호출
sample_csv_and_additional(input_file, output_file_500, output_file_100, n_500)


In [16]:
# smashed data 생성 (500/server side)
import os
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
import torch
from sklearn.model_selection import train_test_split

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-5]  # n번째 레이어의 hidden states를 반환합니다.
        loss = outputs.loss
        return logits, loss, hidden_states

# 데이터 로드 및 전처리
data_A = pd.read_csv("random_500.csv")  # data set A 파일명에 맞게 수정
data_B = pd.read_csv("infected.csv")  # data set B 파일명에 맞게 수정
# 모델 저장 경로
model_path = "Pre-trained.pt"

# X_train, Y_train 생성
X_train = []
Y_train = []

for index, row in data_A.iterrows():  # 중복 제거를 하지 않고 원본 데이터 사용
    patient_id = row["ID"]
    patient_info = [str(row[column]) for column in data_A.columns if column != "ID" and column != "DESCRIPTION"]
    symptoms = ", ".join(data_A[data_A["ID"] == patient_id]["DESCRIPTION"].tolist())
    combined_info = ", ".join(patient_info) + ", " + symptoms
    X_train.append(combined_info)
    if patient_id in data_B.values:
        Y_train.append(1)
    else:
        Y_train.append(0)

print("X_train\n", X_train[:10])
print("Y_train\n", Y_train[:10])
        
# BERT 토크나이저 및 모델 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 모델이 이미 저장되어 있는지 확인하고, 저장된 모델이 있으면 불러오고 없으면 새로운 모델 생성
if os.path.exists(model_path):
    # 저장된 모델이 있을 경우 불러오기
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    print("Pre-train model loaded.")
else:
    # 저장된 모델이 없을 경우 새로운 모델 생성
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    print("New model generated.")

# 입력 데이터를 BERT의 입력 형식으로 변환
max_len = 128  # 입력 시퀀스의 최대 길이

input_ids = []
attention_masks = []

for info in X_train:
    encoded_dict = tokenizer.encode_plus(
                        info,                         # 환자 정보 및 증상
                        add_special_tokens = True,    # [CLS], [SEP] 토큰 추가
                        max_length = max_len,         # 최대 길이 지정
                        pad_to_max_length = True,     # 패딩을 추가하여 최대 길이로 맞춤
                        return_attention_mask = True, # 어텐션 마스크 생성
                        return_tensors = 'pt',        # PyTorch 텐서로 반환
                   )
    
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(Y_train)

# 데이터셋 생성
dataset = TensorDataset(input_ids, attention_masks, labels)

# 데이터로더 생성
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

# 모델을 GPU로 이동
model.to(device)

# 모델 평가
model.eval()
val_accuracy = 0
hidden_states_list = []  # 평가할 때 hidden state를 저장할 리스트
for batch in dataloader:
    batch = tuple(t.to(device) for t in batch)
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'labels': batch[2]}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs[0]  # logits가 outputs의 첫 번째 값입니다.
    logits = logits.detach().cpu().numpy()
    label_ids = inputs['labels'].cpu().numpy()
    val_accuracy += (logits.argmax(axis=1) == label_ids).mean().item()
    # hidden state를 저장합니다.
    hidden_states = outputs[2]
    hidden_states_list.append(hidden_states)
hidden_states_concat = torch.cat(hidden_states_list, dim=0)
hidden_states_concat = hidden_states_concat[:, 0, :].cpu().detach().numpy()
hidden_states_df = pd.DataFrame(hidden_states_concat)
hidden_states_df.to_csv("Dictionary_smashed_data.csv", index=False)

print(f'Validation Accuracy: {val_accuracy / len(dataloader)}')


X_train
 ["5/11/1967, nan, 999-83-3739, S99963041, X22481021X, Mrs., Hortense60, O'Hara248, nan, Spinka232, M, white, nonhispanic, F, Northborough  Massachusetts  US, 1040 Turner Knoll, Milford, Massachusetts, Worcester County, nan, 42.13519473, -71.50138606, 1137082.44, 9598.16, 11/12/2019, 1/21/2020, 803d9786-29a1-466c-9365-0205ea0a031c, 36971009.0, Sinusitis (disorder)", '11/19/1965, nan, 999-87-9895, S99998149, X88597197X, Mr., Waylon572, Reinger292, nan, nan, M, white, nonhispanic, M, Seekonk  Massachusetts  US, 895 Robel Light, Worcester, Massachusetts, Worcester County, 1604.0, 42.32627029, -71.79641283, 1134089.03, 5790.2, 5/7/2004, nan, ef7b5155-d4be-4b0f-98d4-1313f9b7e90b, 162864005.0, Body mass index 30+ - obesity (finding)', '1/20/1912, nan, 999-73-8598, S99979217, X22749225X, Mr., Hans694, Yost751, nan, nan, M, white, nonhispanic, M, Chelsea  Massachusetts  US, 158 Frami Drive Apt 55, Chelmsford, Massachusetts, Middlesex County, nan, 42.62575626, -71.41794612, 372357.04, 6

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Pre-train model loaded.




True
Validation Accuracy: 0.939453125


In [17]:
# smashed data 생성 (100/client side)
import os
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
import torch
from sklearn.model_selection import train_test_split

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=True
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            labels=labels,
            output_hidden_states=output_hidden_states
        )
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-5]  # n번째 레이어의 hidden states를 반환합니다.
        loss = outputs.loss
        return logits, loss, hidden_states

# 데이터 로드 및 전처리
data_A = pd.read_csv("random_100.csv")  # data set A 파일명에 맞게 수정
data_B = pd.read_csv("infected.csv")  # data set B 파일명에 맞게 수정
# 모델 저장 경로
model_path = "Fine-tuned.pt"

# X_train, Y_train 생성
X_train = []
Y_train = []

for index, row in data_A.iterrows():  # 중복 제거를 하지 않고 원본 데이터 사용
    patient_id = row["ID"]
    patient_info = [str(row[column]) for column in data_A.columns if column != "ID" and column != "DESCRIPTION"]
    symptoms = ", ".join(data_A[data_A["ID"] == patient_id]["DESCRIPTION"].tolist())
    combined_info = ", ".join(patient_info) + ", " + symptoms
    X_train.append(combined_info)
    if patient_id in data_B.values:
        Y_train.append(1)
    else:
        Y_train.append(0)

print("X_train\n", X_train[:10])
print("Y_train\n", Y_train[:10])
        
# BERT 토크나이저 및 모델 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 모델이 이미 저장되어 있는지 확인하고, 저장된 모델이 있으면 불러오고 없으면 새로운 모델 생성
if os.path.exists(model_path):
    # 저장된 모델이 있을 경우 불러오기
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.load_state_dict(torch.load(model_path))
    print("Pre-train model loaded.")
else:
    # 저장된 모델이 없을 경우 새로운 모델 생성
    model = CustomBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    print("New model generated.")

# 입력 데이터를 BERT의 입력 형식으로 변환
max_len = 128  # 입력 시퀀스의 최대 길이

input_ids = []
attention_masks = []

for info in X_train:
    encoded_dict = tokenizer.encode_plus(
                        info,                         # 환자 정보 및 증상
                        add_special_tokens = True,    # [CLS], [SEP] 토큰 추가
                        max_length = max_len,         # 최대 길이 지정
                        pad_to_max_length = True,     # 패딩을 추가하여 최대 길이로 맞춤
                        return_attention_mask = True, # 어텐션 마스크 생성
                        return_tensors = 'pt',        # PyTorch 텐서로 반환
                   )
    
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(Y_train)

# 데이터셋 생성
dataset = TensorDataset(input_ids, attention_masks, labels)

# 데이터로더 생성
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

# 모델을 GPU로 이동
model.to(device)

# 모델 평가
model.eval()
val_accuracy = 0
hidden_states_list = []  # 평가할 때 hidden state를 저장할 리스트
for batch in dataloader:
    batch = tuple(t.to(device) for t in batch)
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'labels': batch[2]}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs[0]  # logits가 outputs의 첫 번째 값입니다.
    logits = logits.detach().cpu().numpy()
    label_ids = inputs['labels'].cpu().numpy()
    val_accuracy += (logits.argmax(axis=1) == label_ids).mean().item()
    # hidden state를 저장합니다.
    hidden_states = outputs[2]
    hidden_states_list.append(hidden_states)
hidden_states_concat = torch.cat(hidden_states_list, dim=0)
hidden_states_concat = hidden_states_concat[:, 0, :].cpu().detach().numpy()
hidden_states_df = pd.DataFrame(hidden_states_concat)
hidden_states_df.to_csv("Client_smashed_data.csv", index=False)

print(f'Validation Accuracy: {val_accuracy / len(dataloader)}')


X_train
 ["5/11/1967, nan, 999-83-3739, S99963041, X22481021X, Mrs., Hortense60, O'Hara248, nan, Spinka232, M, white, nonhispanic, F, Northborough  Massachusetts  US, 1040 Turner Knoll, Milford, Massachusetts, Worcester County, nan, 42.13519473, -71.50138606, 1137082.44, 9598.16, 11/12/2019, 1/21/2020, 803d9786-29a1-466c-9365-0205ea0a031c, 36971009.0, Sinusitis (disorder)", '11/19/1965, nan, 999-87-9895, S99998149, X88597197X, Mr., Waylon572, Reinger292, nan, nan, M, white, nonhispanic, M, Seekonk  Massachusetts  US, 895 Robel Light, Worcester, Massachusetts, Worcester County, 1604.0, 42.32627029, -71.79641283, 1134089.03, 5790.2, 5/7/2004, nan, ef7b5155-d4be-4b0f-98d4-1313f9b7e90b, 162864005.0, Body mass index 30+ - obesity (finding)', '1/20/1912, nan, 999-73-8598, S99979217, X22749225X, Mr., Hans694, Yost751, nan, nan, M, white, nonhispanic, M, Chelsea  Massachusetts  US, 158 Frami Drive Apt 55, Chelmsford, Massachusetts, Middlesex County, nan, 42.62575626, -71.41794612, 372357.04, 6

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Pre-train model loaded.
True
Validation Accuracy: 0.9107142857142857


In [25]:
# Top N NN산출
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def read_csv_data(file_paths):
    data = []
    for file_path in file_paths:
        df = pd.read_csv(file_path, header=None)
        data.append(df.values)
    return data

def calculate_similarity(client_data, server_data):
    similarity_scores = cosine_similarity(client_data, server_data)
    return similarity_scores

def get_top_n_inferences(similarity_scores, n):
    top_n_indices = np.argsort(similarity_scores, axis=1)[:, -n:]
    return top_n_indices

def main(client_files, server_files, n):
    # Read CSV data
    client_data = read_csv_data(client_files)
    server_data = read_csv_data(server_files)

    # Exclude the first row from server data for comparison
    server_data = [data[1:] for data in server_data]

    # Convert data to numpy arrays
    client_data = np.vstack(client_data)
    server_data = np.vstack(server_data)

    # Calculate similarity
    similarity_scores = calculate_similarity(client_data, server_data)

    # Get top N inferences
    top_n_indices = get_top_n_inferences(similarity_scores, n)

    # Output top N inferences
    for i, indices in enumerate(top_n_indices):
        print(f"Top {n} inferences for client {i}:")
        for idx in indices:
            print(f"Server {idx} with similarity score {similarity_scores[i][idx]}")

if __name__ == "__main__":
    client_files = ['Client_smashed_data.csv']
    server_files = ['Dictionary_smashed_data.csv']
    n = 5  # Number of top inferences
    main(client_files, server_files, n)


Top 5 inferences for client 0:
Server 69 with similarity score -0.027380570578298798
Server 141 with similarity score -0.027331109428647766
Server 289 with similarity score -0.02709142236258632
Server 151 with similarity score -0.025491445245368997
Server 247 with similarity score -0.02339054532210398
Top 5 inferences for client 1:
Server 431 with similarity score 0.9670944617132706
Server 358 with similarity score 0.9694984380883578
Server 153 with similarity score 0.9712645401455625
Server 195 with similarity score 0.9721232025845163
Server 279 with similarity score 0.9853076338775929
Top 5 inferences for client 2:
Server 452 with similarity score 0.9693298776947514
Server 87 with similarity score 0.96933228996485
Server 8 with similarity score 0.9695756022357571
Server 432 with similarity score 0.9696847321832198
Server 269 with similarity score 0.9757509783407778
Top 5 inferences for client 3:
Server 413 with similarity score 0.9721511980602019
Server 207 with similarity score 0.97

In [28]:
# Top@N Accuracy 산출
import pandas as pd
import numpy as np

def euclidean_distance(vector1, vector2):
    return np.linalg.norm(vector1 - vector2)

def calculate_top_n_accuracy(client_data, dictionary_data, n):
    success_count = 0
    success_indices = []
    for i in range(len(client_data)):
        client_vector = np.array(client_data.iloc[i])
        distances = []
        for j in range(1, len(dictionary_data)):  # 인덱스 0을 제외하고 비교
            dictionary_vector = np.array(dictionary_data.iloc[j])
            distance = euclidean_distance(client_vector, dictionary_vector)
            distances.append((distance, j))  # (distance, index)
        distances.sort()  # Sort distances in ascending order
        top_n_indices = [index for _, index in distances[:n]]  # Get top N indices
        if i+1 in top_n_indices:  # client 데이터 인덱스를 1부터 시작하므로 i+1 사용
            success_count += 1
            success_indices.append(i+1)  # client 데이터 인덱스를 1부터 시작하므로 i+1 사용
    return success_count / len(client_data), success_indices

# 두 암호화된 파일 로드
client_data = pd.read_csv("Client_smashed_data.csv", header=None)
dictionary_data = pd.read_csv("Dictionary_smashed_data.csv", header=None)

# 상위 N개 설정
top_n = 5

# Top@N 정확도 계산
accuracy, success_indices = calculate_top_n_accuracy(client_data, dictionary_data, top_n)
print(f"Top@{top_n} Accuracy: {accuracy}")
print(f"Success Indices: {success_indices}")


Top@5 Accuracy: 0.009900990099009901
Success Indices: [85]
