# 训练篇

1. 选择nn.loss损失函数
2. 选择torch.optim优化算法
3. 设置超参数
4. 设置tensorboard进行可视化

In [31]:
import torch 
import torch.nn as nn # 包含loss和各种blocks, layers
from torch.utils.data import * # 包括Dataset和DataLoader
from tensorboardX import SummaryWriter 

from resnet import * 
from generate_dataset import * 

配置如下：
1. settings: 训练集地址和验证集地址
2. writer for tensorboard
3. prepare dataloader
4. load model
5. loss func 
6. optim 
7. let's go 

然后过程如下:
```python
for ep in range(EPOCHS):
    training ...
    validation ...
testing ...
```



In [32]:
torch.cuda.is_available()

True

In [33]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f0d896dbcc0>

In [34]:
def train(model_dir):
    # path for training and testing data
    train_data_path = "/share/mal/malware/data/train"
    test_data_path = "/share/mal/malware/data/test"
    
    # writer for debug
    writer = SummaryWriter(comment="malware_classification%resnet34")
    
    # prepare dataloader
    train_set = generate_dataset(train_data_path)
    test_set = generate_dataset(test_data_path)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=True)
    
    # load model
    model = ResNet(num_block=[3, 4, 6, 3], num_classes=8)
    model.cuda()
    model.train()
    
    # loss func
    criterion = nn.CrossEntropyLoss()
    
    # optim
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)
    
    # let's go 
    EPOCH = 20
    step = 0
    best_acc = 0.0
    
    for ep in range(EPOCH):
        # training
        for idx, data in enumerate(train_loader):
            x, y = data
            y = torch.from_numpy(np.asarray(y, dtype=np.long))
            y-=1 # 训练集标签从1-9，映射到0-8
            x = x.cuda()
            y = y.cuda()
            
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optim.step()
            optim.zero_grad()
            
            if step % 10 == 0:
                print ("epoch {} step {}: loss={}".format(ep, step, loss))
            writer.add_scalar("Loss", loss, step)
            step += 1
        
        # validation
        correct = 0
        total = 0
        model.eval()
        for idx, data in enumerate(test_loader):
            x, y = data
            y = torch.from_numpy(np.asarray(y, dtype=np.long))
            y-=1
            x = x.cuda()
            y = y.cuda()

            y_pred = model(x)
            prediction = torch.argmax(y_pred, 1)
            correct += (prediction == y).sum()
            total += len(y)
        
        acc = correct/total
        print ("epoch {}: acc={}".format(ep, acc))
        writer.add_scalar("acc", acc, ep)
        
        # save the epoch model and replace the best model
        model_path = os.path.join(model_dir, 'malware_classification%resnet34%{}.pth'.format(ep))
        torch.save(model, model_path)
        if acc > best_acc:
            best_model_path = os.path.join(model_dir, 'malware_classification%resnet34%best.pth')
            torch.save(model, best_model_path)
            best_acc = acc
            
    writer.close()

In [35]:
model_path = "/root/paperwithcode/第三周-训练篇/malware_dpcnn/trained_models/" 
train(model_path)

epoch 0 step 0: loss=2.076631784439087
epoch 0 step 10: loss=2.07023286819458
epoch 0 step 20: loss=2.0594799518585205
epoch 0 step 30: loss=2.0616228580474854
epoch 0 step 40: loss=2.0597167015075684
epoch 0 step 50: loss=2.045849084854126
epoch 0 step 60: loss=2.0501255989074707
epoch 0 step 70: loss=2.0405056476593018
epoch 0 step 80: loss=2.0442278385162354
epoch 0 step 90: loss=2.0241730213165283
epoch 0 step 100: loss=2.027639150619507
epoch 0 step 110: loss=2.0041229724884033
epoch 0 step 120: loss=2.0299246311187744
epoch 0 step 130: loss=2.015204668045044
epoch 0: acc=0.2918017506599426
epoch 1 step 140: loss=1.9951720237731934
epoch 1 step 150: loss=1.8323334455490112
epoch 1 step 160: loss=1.716463327407837
epoch 1 step 170: loss=1.7289382219314575
epoch 1 step 180: loss=1.7048345804214478
epoch 1 step 190: loss=1.7646193504333496
epoch 1 step 200: loss=1.7421860694885254
epoch 1 step 210: loss=1.778476357460022
epoch 1 step 220: loss=1.8079018592834473
epoch 1 step 230: los

In [36]:
# 保存最优模型字典
best_model_path = os.path.join(model_path, 'malware_classification%resnet34%best.pth')
state_dict_path = os.path.join(model_path, 'malware_classification%resnet34%best.pt')
model = ResNet(num_block=[3, 4, 6, 3])
best_model = torch.load(best_model_path)
torch.save(best_model.state_dict(), state_dict_path)