# オフライン学習

学習データからモデルを作成する。オフライン環境(Windows ノートPC)でも実行可能。

In [None]:
# Task identifier: prefixes the data directories (data/<TASK>_<name>) and the default model file name
TASK = "off001"

### 学習データ読み込み

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from xy_dataset import XYDataset


# Label categories; the model predicts one (x, y) coordinate per category
CATEGORIES = ['apex']

# Dataset name suffixes; each one is loaded from the directory data/<TASK>_<name>
DATASETS = ['A', 'B', 'REMARK']

# Per-sample pipeline: colour jitter augmentation, resize to 224x224,
# conversion to a tensor, then per-channel mean/std normalisation
# (values match those used by torchvision's pretrained models)
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# One XYDataset per name, keyed by that name, built in DATASETS order
datasets = {
    name: XYDataset('data/{}_{}'.format(TASK, name), CATEGORIES, TRANSFORMS, random_hflip=True)
    for name in DATASETS
}

In [None]:
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from ipywidgets import Button, Layout, Textarea, HBox, VBox

# The first configured dataset is active until the dropdown selects another one
dataset = datasets[DATASETS[0]]

# Selectors for dataset and category, plus a counter fed by dataset.get_count()
dataset_widget = ipywidgets.Dropdown(description='dataset', options=DATASETS)
category_widget = ipywidgets.Dropdown(description='category', options=dataset.categories)
count_widget = ipywidgets.IntText(description='count')

# Observers only fire on changes, so seed the counter once by hand
count_widget.value = dataset.get_count(category_widget.value)

def set_dataset(change):
    """Make the dropdown's newly selected dataset active and refresh the counter.

    Args:
        change: ipywidgets change dict; ``change['new']`` is the selected dataset name.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    count_widget.value = dataset.get_count(category_widget.value)


dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Refresh the counter for the newly selected category of the active dataset.

    Args:
        change: ipywidgets change dict; ``change['new']`` is the selected category.
    """
    new_category = change['new']
    count_widget.value = dataset.get_count(new_category)


category_widget.observe(update_counts, names='value')


# Dataset/category selectors and the sample counter, stacked vertically.
# Not displayed on its own; it is embedded in the combined UI (all_widget).
data_collection_widget = ipywidgets.VBox(children=[dataset_widget, category_widget, count_widget])

### ベースモデル準備(RESNET)

In [None]:
import torch
import torchvision
import time

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(DEVICE)

# The network regresses one (x, y) coordinate per category
output_dim = len(dataset.categories) * 2

# Backbone: pretrained ResNet-18 whose final fully-connected layer is replaced
# by a regression head of size output_dim.
# Alternatives (use one pair instead of the two ResNet-18 lines below):
#   ALEXNET:      model = torchvision.models.alexnet(pretrained=True)
#                 model.classifier[-1] = torch.nn.Linear(4096, output_dim)
#   SQUEEZENET:   model = torchvision.models.squeezenet1_1(pretrained=True)
#                 model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
#                 model.num_classes = len(dataset.categories)
#   RESNET 34:    model = torchvision.models.resnet34(pretrained=True)
#                 model.fc = torch.nn.Linear(512, output_dim)
#   DENSENET 121: model = torchvision.models.densenet121(pretrained=True)
#                 model.classifier = torch.nn.Linear(model.num_features, output_dim)
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)

model = model.to(device)

# Model persistence controls; the file name defaults to "<TASK>_model.pth"
model_path_widget = ipywidgets.Text(description='model path', value=TASK + '_model.pth')
model_load_button = ipywidgets.Button(description='load model')
model_save_button = ipywidgets.Button(description='save model')

def load_model(c):
    """Load weights into ``model`` from ``data/<TASK>_A/<model path>`` and log the elapsed time.

    Args:
        c: The clicked button (unused; supplied by ``Button.on_click``).
    """
    start = time.time()
    write_log(model_path_widget.value + "の読込処理を開始します。")
    # map_location moves every stored tensor onto the current device. Without it,
    # a checkpoint saved on a CUDA machine cannot be deserialised on a CPU-only PC.
    state_dict = torch.load('data/'+TASK+'_A/'+model_path_widget.value, map_location=device)
    model.load_state_dict(state_dict)
    process_time = time.time() - start
    write_log(model_path_widget.value + "の読込処理を終了しました(処理時間:{0:.3f}秒)。".format(process_time))
model_load_button.on_click(load_model)
    
def save_model(c):
    """Write ``model``'s state_dict to ``data/<TASK>_A/<model path>`` and log the elapsed time.

    Args:
        c: The clicked button (unused; supplied by ``Button.on_click``).
    """
    file_name = model_path_widget.value
    began = time.time()
    write_log(file_name + "の書込処理を開始します。")
    torch.save(model.state_dict(), 'data/' + TASK + '_A/' + file_name)
    elapsed = time.time() - began
    write_log(file_name + "の書込処理を終了しました(処理時間:{0:.3f}秒)。".format(elapsed))
model_save_button.on_click(save_model)

# Model path field above a row holding the load and save buttons
model_widget = ipywidgets.VBox(children=[
    model_path_widget,
    ipywidgets.HBox(children=[model_load_button, model_save_button]),
])

### 学習

In [None]:
# Fixed-height text area used as the log pane
l = Layout(flex='0 1 auto', height='100px', min_height='100px', width='auto')
process_widget = ipywidgets.Textarea(description='ログ', value='', layout=l)

# Running message counter shared by every write_log call
process_no = 0


def write_log(msg):
    """Prepend ``msg`` to the log pane as a numbered line (newest entry on top).

    Args:
        msg: The message text to log.
    """
    global process_no
    process_no += 1
    entry = str(process_no) + ': ' + msg + '\n'
    process_widget.value = entry + process_widget.value

In [None]:
import matplotlib.pyplot as plt

# Mini-batch size for the DataLoader built in train_eval
BATCH_SIZE = 8

# Adam with its default hyper-parameters over all model parameters.
# Alternative: torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters())

# Training controls and progress/loss readouts
epochs_widget = ipywidgets.IntText(value=1, description='epochs')
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(description='progress', min=0.0, max=1.0)

# One loss value per finished training epoch; plotted at the end of train_eval
loss_history = []
    
def _run_epoch(loader, is_training):
    """Run one pass over ``loader``, update the progress/loss widgets, and return the mean per-sample loss.

    Args:
        loader: DataLoader yielding ``(images, category_idx, xy)`` batches.
        is_training: When True, back-propagate and step ``optimizer`` after each batch.

    Returns:
        The mean per-sample loss over the pass (0.0 if the loader was empty).
    """
    seen = 0
    sum_loss = 0.0
    # Build the autograd graph only when training; evaluation needs no gradients.
    with torch.set_grad_enabled(is_training):
        for images, category_idx, xy in loader:
            # send data to device
            images = images.to(device)
            xy = xy.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(images)

            # MSE between each sample's target xy and the (x, y) output pair
            # that belongs to that sample's category
            loss = 0.0
            for batch_idx, cat_idx in enumerate(category_idx.flatten()):
                loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx + 2] - xy[batch_idx]) ** 2)
            loss /= len(category_idx)

            if is_training:
                loss.backward()
                optimizer.step()

            count = len(category_idx.flatten())
            seen += count
            # ``loss`` is the mean over this batch, so weight it by the batch
            # size. Summing raw batch means and dividing by ``seen`` would
            # under-report the loss by roughly a factor of BATCH_SIZE.
            sum_loss += float(loss) * count
            progress_widget.value = seen / len(dataset)
            loss_widget.value = sum_loss / seen
    return sum_loss / seen if seen else 0.0


def train_eval(is_training):
    """Train ``model`` for several epochs, or run one evaluation pass, over the active dataset.

    The model/save/load/train/evaluate buttons are disabled while running and
    always re-enabled afterwards. Errors are written to the log pane instead
    of being lost. Finally the per-epoch loss history is plotted.

    Args:
        is_training: True to train for ``epochs_widget.value`` epochs (the
            widget is decremented after each finished epoch). False to run
            exactly one gradient-free pass in eval mode.
    """
    global model

    buttons = (model_save_button, model_load_button, train_button, eval_button)
    # Fallback so the summary below never hits an unbound name if setup fails
    start = time.time()
    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        for button in buttons:
            button.disabled = True
        time.sleep(1)
        start = time.time()

        if is_training:
            model = model.train()
            write_log(str(epochs_widget.value)+"Epochの学習を開始します。")
            epoch_num = 1
            while epochs_widget.value > 0:
                epoch_start = time.time()
                epoch_loss = _run_epoch(train_loader, is_training=True)
                process_time = time.time() - epoch_start
                write_log(str(epoch_num)+"Epoch目の学習が終了しました(処理時間:{0:.3f}秒)。".format(process_time))
                epochs_widget.value = epochs_widget.value - 1
                epoch_num += 1
                loss_history.append(epoch_loss)
        else:
            # Evaluation mode (e.g. disables dropout; BatchNorm uses running stats)
            model = model.eval()
            write_log("評価モードで学習を開始します(1 Epochのみ)。")
            # Exactly one pass regardless of epochs_widget, which is 0 right
            # after a training run has counted it down.
            _run_epoch(train_loader, is_training=False)
    except Exception as e:
        # Best-effort UI callback: report the failure in the log pane rather
        # than swallowing it (the former ``except e:`` raised NameError).
        write_log("エラーが発生しました: {0!r}".format(e))
    finally:
        model = model.eval()
        for button in buttons:
            button.disabled = False

    process_time = time.time() - start
    if is_training:
        write_log("すべての学習が終了しました(トータルの処理時間:{0:.3f}秒)。".format(process_time))
    else:
        write_log("評価が終了しました(トータルの処理時間:{0:.3f}秒)。".format(process_time))

    plt.close()
    plt.plot(loss_history)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    # Losses are true per-sample means, so only pin the lower bound and let
    # the upper bound autoscale instead of clipping at a fixed value.
    plt.ylim(bottom=0)
        
    
def _on_train_click(_button):
    """Button callback: start a training session."""
    train_eval(is_training=True)


def _on_eval_click(_button):
    """Button callback: run a single evaluation pass."""
    train_eval(is_training=False)


train_button.on_click(_on_train_click)
eval_button.on_click(_on_eval_click)

# Epoch count, progress bar and loss readout above the train button.
# NOTE(review): eval_button is wired above but not placed in this layout,
# so evaluation cannot be started from the displayed UI.
train_eval_widget = ipywidgets.VBox(children=[
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox(children=[train_button]),
])

### 実行

「train」で学習後、「save model」で保存。

追加学習の場合は、最初に「load model」で既存モデルを読み込み、「train」で学習した後、「save model」で保存。

In [None]:
# Full UI, top to bottom: data selection, training controls, model I/O, log pane
all_widget = ipywidgets.VBox(children=[
    ipywidgets.HBox(children=[data_collection_widget]),
    train_eval_widget,
    model_widget,
    process_widget,
])
display(all_widget)