In [24]:
# Standard library
import json
import time
from collections import OrderedDict, namedtuple
from itertools import product

# Third-party
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from IPython.display import clear_output, display
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Global torch configuration: wider tensor printing; autograd explicitly on.
torch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True)

In [2]:
# Record library versions so results can be tied to the environment that produced them.
print(torch.__version__, torchvision.__version__, sep='\n')


1.11.0
0.12.0


In [3]:
def get_num_correct(preds, labels):
    """Count correct predictions in a batch.

    Args:
        preds: score tensor whose dim 1 indexes classes.
        labels: tensor of target class indices, compared element-wise with the
            argmax of ``preds`` over dim 1.

    Returns:
        Number of matches as a Python int.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return int(hits.sum())

In [4]:
class Network(nn.Module):
    """Small CNN: two conv -> ReLU -> 2x2 max-pool stages, then three linear layers.

    ``fc1`` expects ``12 * 4 * 4`` features, which matches a 1x28x28 input
    (28 -> 24 -> 12 after conv1 + pool, 12 -> 8 -> 4 after conv2 + pool).
    The final layer returns raw (un-normalized) scores for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become parameter names (e.g. 'conv1.weight'), which are
        # what RunManager logs as TensorBoard histograms.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        # Feature extractor: each conv is followed by ReLU and a 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        t = self.flatten(t)

        # Classifier head: two ReLU hidden layers, then unactivated output scores.
        for hidden in (self.fc1, self.fc2):
            t = F.relu(hidden(t))
        return self.out(t)

In [None]:
# FashionMNIST training split, downloaded into ./data on first use; images become tensors.
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=100)

In [18]:

class RunBuilder:
    """Expand a mapping of hyper-parameter value lists into every combination."""

    @staticmethod
    def get_run(params):
        """Return one ``Run`` namedtuple per element of the cartesian product of ``params`` values.

        Args:
            params: mapping of hyper-parameter name -> iterable of candidate values.
                The names become the ``Run`` field names, in the mapping's order.
        """
        run_type = namedtuple('Run', params.keys())
        return [run_type(*combo) for combo in product(*params.values())]

In [40]:
class RunManager:
    """Track per-run / per-epoch training metrics and log them to TensorBoard.

    Usage (as in the training sweep): call ``begin_run`` once per hyper-parameter
    combination, wrap each epoch in ``begin_epoch`` / ``end_epoch``, call
    ``track_loss`` and ``track_num_correct`` once per batch, ``end_run`` after the
    run, and ``save`` at the end to dump every recorded epoch row to CSV and JSON.
    """

    def __init__(self):
        # Per-epoch accumulators (reset by begin_epoch).
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Per-run state.
        self.run_params = None
        self.run_count = 0
        self.run_data = []  # one OrderedDict per finished epoch, across all runs
        self.run_start_time = None

        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a run: open a SummaryWriter, log one image grid and the model graph.

        Args:
            run: namedtuple of hyper-parameters (e.g. from RunBuilder.get_run); its
                repr is used as the TensorBoard run comment.
            network: the nn.Module being trained.
            loader: the run's DataLoader; one batch is drawn from it for logging.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')

        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        self.tb.add_image('images', grid)

        # Trace the graph on the device the model actually lives on, so add_graph
        # works even when `run` carries no (or a mismatching) `device` field.
        first_param = next(self.network.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = getattr(run, 'device', 'cpu')
        self.tb.add_graph(self.network, images.to(device))

    def end_run(self):
        """Close the run's SummaryWriter and reset the epoch counter for the next run."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start timing a new epoch and reset its loss / correct-count accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Log the finished epoch to TensorBoard, append it to ``run_data`` and redisplay the table.

        Loss and accuracy are normalized by ``len(self.loader.dataset)``.
        """
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        num_samples = len(self.loader.dataset)
        loss = self.epoch_loss / num_samples
        accuracy = self.epoch_num_correct / num_samples

        # The 'Accurary' / 'accurary' spellings are kept as-is so TensorBoard tags and
        # the saved CSV/JSON columns stay stable for existing consumers.
        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accurary', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # A parameter that took no part in backward() has grad None, which
            # add_histogram would reject; skip its gradient histogram.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accurary"] = accuracy
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration
        for k, v in self.run_params._asdict().items():
            results[k] = v
        self.run_data.append(results)

        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        clear_output(wait=True)
        display(df)

    def track_loss(self, loss, batch_size=None):
        """Accumulate a batch's loss, weighted by the number of samples in the batch.

        Args:
            loss: scalar loss tensor for the batch, assumed to be a mean over the
                batch (as F.cross_entropy returns by default).
            batch_size: number of samples in this batch. Defaults to
                ``self.loader.batch_size``. Pass the real size (e.g. ``labels.size(0)``)
                when the final batch can be smaller, otherwise that batch is over-weighted.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the batch's number of correct predictions to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose argmax over dim 1 equals ``labels``; returns an int."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, filename):
        """Write all recorded epoch rows to ``{filename}.csv`` and ``{filename}.json``."""
        pd.DataFrame.from_dict(
            self.run_data,
            orient='columns',
        ).to_csv(f'{filename}.csv')

        with open(f'{filename}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)

In [41]:
# Hyper-parameter grid: RunBuilder turns every combination into one run.
parameters = OrderedDict(
    lr=[.01],
    batch_size=[1000, 20000],
    num_workers=[0, 1],
    device=['cuda', 'cpu'],
)
NUM_EPOCHS = 1

m = RunManager()

for run in RunBuilder.get_run(parameters):
    # Skip GPU runs on machines without CUDA instead of crashing the whole sweep.
    if run.device == 'cuda' and not torch.cuda.is_available():
        continue

    device = torch.device(run.device)
    network = Network().to(device)
    # shuffle=True matches train_loader; without it every epoch visits batches
    # in the same fixed order.
    loader = DataLoader(
        train_set,
        batch_size=run.batch_size,
        num_workers=run.num_workers,
        shuffle=True,
    )
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    m.begin_run(run, network, loader)
    for epoch in range(NUM_EPOCHS):
        m.begin_epoch()
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            preds = network(images)
            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            m.track_num_correct(preds, labels)
        m.end_epoch()
    m.end_run()

m.save('results')

Unnamed: 0,run,epoch,loss,accurary,epoch duration,run duration,lr,batch_size,num_workers,device
0,1,1,1.004734,0.614367,7.554863,7.930863,0.01,1000,0,cuda
1,2,1,0.97899,0.625783,4.793993,5.175998,0.01,1000,0,cpu
2,3,1,1.04428,0.59395,2.51893,3.902678,0.01,1000,1,cuda
3,4,1,0.949709,0.640617,3.271995,4.471155,0.01,1000,1,cpu
4,5,1,2.279764,0.112733,2.317997,8.336851,0.01,20000,0,cuda
5,6,1,2.276822,0.138183,4.595999,11.576828,0.01,20000,0,cpu
6,7,1,2.289105,0.100567,2.71856,9.776948,0.01,20000,1,cuda
7,8,1,2.269968,0.151433,4.064014,12.479194,0.01,20000,1,cpu


In [42]:
# Compare runs by speed: fastest epoch first.
pd.DataFrame(m.run_data).sort_values(by='epoch duration')

Unnamed: 0,run,epoch,loss,accurary,epoch duration,run duration,lr,batch_size,num_workers,device
4,5,1,2.279764,0.112733,2.317997,8.336851,0.01,20000,0,cuda
2,3,1,1.04428,0.59395,2.51893,3.902678,0.01,1000,1,cuda
6,7,1,2.289105,0.100567,2.71856,9.776948,0.01,20000,1,cuda
3,4,1,0.949709,0.640617,3.271995,4.471155,0.01,1000,1,cpu
7,8,1,2.269968,0.151433,4.064014,12.479194,0.01,20000,1,cpu
5,6,1,2.276822,0.138183,4.595999,11.576828,0.01,20000,0,cpu
1,2,1,0.97899,0.625783,4.793993,5.175998,0.01,1000,0,cpu
0,1,1,1.004734,0.614367,7.554863,7.930863,0.01,1000,0,cuda
