In [40]:
from torch.utils.data import Dataset
from torch import nn
import torch
from torch.utils.data import DataLoader
#from q71 import SLPNet
from q73 import NewsDataset
from q74 import load_Dataloader
from q78 import calculate_loss_and_accuracy
import time
from torch.nn import functional as F

# class SLPNet(nn.Module):
#     def __init__(self, input_size, output_size,mid_size):        
#         super().__init__()
#         self.fc = nn.Linear(input_size,mid_size)        
#         self.fc2= nn.Linear(mid_size, mid_size)
#         self.fc3=nn.Linear(mid_size, output_size)
#         #正則化とドロップアウトを追加
#         self.dropout = nn.Dropout(0.5)
#         self.bn = nn.BatchNorm1d(mid_size)
# #        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # 正規乱数で重みを初期化
        
#     def forward(self, x):
#         x = F.relu(self.fc(x))
#         x = F.relu(self.fc2(x))
#         x = self.dropout(x)
#         x = F.relu(self.bn(x))
#         x = self.dropout(x)
#         x = F.relu(self.fc3(x))        
# #        x = self.relu(x)
# #        x = self.dropout(x)
        
#         return x
    
class MLPNet(nn.Module):
    def __init__(self, input_size, mid_size, output_size, mid_layers):
        super().__init__()
        self.mid_layers = mid_layers
        self.fc = nn.Linear(input_size, mid_size)
        self.fc_mid = nn.Linear(mid_size, mid_size)
        self.fc_out = nn.Linear(mid_size, output_size) 
        self.bn = nn.BatchNorm1d(mid_size)
        
    def forward(self, x):
        x = F.relu(self.fc(x))
        
        for _ in range(self.mid_layers):
            x = F.relu(self.bn(self.fc_mid(x)))
            x = F.relu(self.fc_out(x))
            
        return x    

def train_model(dataloader_train, dataloader_valid,model, criterion, optimizer, num_epochs, device=None):
    # GPUに送る
    model.to(device)

    
    # 学習
    log_train = []
    log_valid = []
    
    scheduler =torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-5)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer,num_epochs/5,0.5)
    for epoch in range(num_epochs):
        # 開始時刻の記録
        s_time = time.time()
        
        # 訓練モードに設定 
        model.train()
        
        for inputs, labels in dataloader_train:
            # 勾配をゼロで初期化
            optimizer.zero_grad()
            # 順伝播 + 誤差逆伝播 + 重み更新
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model.forward(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
        # 損失と正解率の算出
        loss_train, acc_train = calculate_loss_and_accuracy(model, criterion, dataloader_train, device)
        loss_valid, acc_valid = calculate_loss_and_accuracy(model, criterion, dataloader_valid, device)
        log_train.append([loss_train, acc_train])
        log_valid.append([loss_valid, acc_valid])
            
        # チェックポイントの保存
#        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, f'checkpoint{epoch + 1}.pt')
            
        # 終了時刻の記録
        e_time = time.time()
        
        # ログを出力
        print(f'epoch: {epoch + 1}, loss_train: {loss_train:.4f}, accuracy_train: {acc_train:.4f}, loss_valid: {loss_valid:.4f}, accuracy_valid: {acc_valid:.4f}, {(e_time - s_time):.4f}sec') 
        
        if epoch > 2 and log_valid[epoch - 3][0] <= log_valid[epoch - 2][0] <= log_valid[epoch - 1][0] <= log_valid[epoch][0]:
            break
        scheduler.step()
    return {'train': log_train, 'valid': log_valid}

if __name__=="__main__":

    
#    model = SLPNet(300, 4,200)# モデルの定義    
    model = MLPNet(300, 200, 4, 1)
    criterion = nn.CrossEntropyLoss()# 損失関数の定義
    #optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# オプティマイザの定義
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    device = torch.device('cuda')# デバイスの指定
    
    # モデルの学習
    dataloader_train = load_Dataloader("train",64)
#    dataloader_valid = load_Dataloader("valid", 64)
#    dataloader_test = load_Dataloader("test", 64)

#     X_train=torch.load("X_train2.pt")
#     Y_train=torch.load("Y_train2.pt")
#     dataset_train=NewsDataset(X_train,Y_train)
#     dataloader_train=DataLoader(dataset_train, 64, shuffle=True)
    
    X_valid=torch.load("X_valid2.pt")
    Y_valid=torch.load("Y_valid2.pt")
    dataset_valid=NewsDataset(X_valid,Y_valid)
    dataloader_valid=DataLoader(dataset_valid, len(dataset_valid), shuffle=False)
    
    X_test=torch.load("X_test2.pt")
    Y_test=torch.load("Y_test2.pt")
    dataset_test=NewsDataset(X_test,Y_test)
    dataloader_test=DataLoader(dataset_test,len(dataset_test), shuffle=False)
    
    
    log = train_model(dataloader_train, dataloader_valid,model, criterion, optimizer, 1000,device=device)
    
    loss_train, acc_train = calculate_loss_and_accuracy(model, criterion, dataloader_train, device)
    loss_test, acc_test = calculate_loss_and_accuracy(model, criterion, dataloader_test, device)
    print(f"学習データでの正解率:{acc_train}")
    print(f"評価データでの正解率:{acc_test}")

epoch: 1, loss_train: 0.9033, accuracy_train: 0.7393, loss_valid: 0.9031, accuracy_valid: 0.7369, 0.3723sec
epoch: 2, loss_train: 0.7501, accuracy_train: 0.7581, loss_valid: 0.7507, accuracy_valid: 0.7571, 0.3680sec
epoch: 3, loss_train: 0.6833, accuracy_train: 0.7688, loss_valid: 0.6839, accuracy_valid: 0.7691, 0.3858sec
epoch: 4, loss_train: 0.6409, accuracy_train: 0.7774, loss_valid: 0.6417, accuracy_valid: 0.7796, 0.3852sec
epoch: 5, loss_train: 0.6045, accuracy_train: 0.7900, loss_valid: 0.6054, accuracy_valid: 0.7916, 0.3810sec
epoch: 6, loss_train: 0.5722, accuracy_train: 0.8027, loss_valid: 0.5732, accuracy_valid: 0.8051, 0.3695sec
epoch: 7, loss_train: 0.5436, accuracy_train: 0.8145, loss_valid: 0.5456, accuracy_valid: 0.8133, 0.3943sec
epoch: 8, loss_train: 0.5181, accuracy_train: 0.8264, loss_valid: 0.5207, accuracy_valid: 0.8276, 0.3780sec
epoch: 9, loss_train: 0.4961, accuracy_train: 0.8384, loss_valid: 0.5007, accuracy_valid: 0.8366, 0.3638sec
epoch: 10, loss_train: 0.477

epoch: 77, loss_train: 0.2320, accuracy_train: 0.9230, loss_valid: 0.3246, accuracy_valid: 0.8898, 0.3700sec
epoch: 78, loss_train: 0.2269, accuracy_train: 0.9261, loss_valid: 0.3244, accuracy_valid: 0.8928, 0.3700sec
epoch: 79, loss_train: 0.2264, accuracy_train: 0.9264, loss_valid: 0.3257, accuracy_valid: 0.8943, 0.3723sec
epoch: 80, loss_train: 0.2237, accuracy_train: 0.9256, loss_valid: 0.3230, accuracy_valid: 0.8921, 0.3700sec
epoch: 81, loss_train: 0.2226, accuracy_train: 0.9264, loss_valid: 0.3237, accuracy_valid: 0.8943, 0.3750sec
epoch: 82, loss_train: 0.2211, accuracy_train: 0.9287, loss_valid: 0.3223, accuracy_valid: 0.8958, 0.3720sec
epoch: 83, loss_train: 0.2199, accuracy_train: 0.9285, loss_valid: 0.3216, accuracy_valid: 0.8958, 0.3718sec
epoch: 84, loss_train: 0.2181, accuracy_train: 0.9293, loss_valid: 0.3234, accuracy_valid: 0.8943, 0.3798sec
epoch: 85, loss_train: 0.2174, accuracy_train: 0.9291, loss_valid: 0.3227, accuracy_valid: 0.8928, 0.3954sec
epoch: 86, loss_tra