In [1]:
import os
import numpy as np
from PIL import Image
import torchvision
from torchvision import transforms
import torch
import pickle
import torch.nn as nn
from tqdm import tqdm
from datasets.robot_hand import RobotHandDataset

In [2]:
torch.cuda.is_available()

True

In [3]:
dataset = RobotHandDataset('train', '/mnt/d/data/robot_hand', mode='tail')
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)

In [4]:
model = torchvision.models.resnet50()
model.conv1 = nn.Conv2d(10, 64, kernel_size=7, stride=2, padding=3,bias=False)
model.fc = nn.Linear(2048, 12)

print(len(data_loader))
for img, label in data_loader:
    print(img.shape)

    output = model(img)
    print(output)
    break

1019
torch.Size([1, 12, 224, 224])


RuntimeError: Given groups=1, weight of size [64, 10, 7, 7], expected input[1, 12, 224, 224] to have 10 channels, but got 12 channels instead

In [5]:
def train(logpath='', dataroot='/mnt/d/data/robot_hand', num_epochs=10):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    model = torchvision.models.resnet50()
    model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3,bias=False)
    model.fc = nn.Linear(2048, 12)
    model.to(device)

    dataset = RobotHandDataset(split='train', dataroot=dataroot, mode='head')
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    dataset_eval = RobotHandDataset(split='train', dataroot=dataroot, mode='tail')
    data_loader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=4)
        
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    len_dataloader = len(data_loader)

    for epoch in range(num_epochs):
        model.train()
        i = 0    
        epoch_loss = 0
        tqdm_bar = tqdm(data_loader, desc="Epoch {}/{}".format(epoch+1, num_epochs))
        for imgs, labels in tqdm_bar:
            i += 1
            # imgs = list(img.to(device) for img in imgs)
            # annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
            imgs = imgs.to(device)
            labels = labels.to(device) * 1000
            output = model(imgs)
            loss = nn.MSELoss()(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
            tqdm_bar.set_description(f'Iteration: {i}/{len_dataloader}, Loss: {loss}')
            epoch_loss += loss
        
        # get eval set loss
        with torch.no_grad():
            model.eval()
            eval_loss = 0
            print('Evaluating...')
            for imgs, labels in data_loader_eval:
                imgs = imgs.to(device)
                labels = labels.to(device) * 1000
                output = model(imgs)
                loss = nn.MSELoss()(output, labels)
                eval_loss += loss
        
        scheduler.step()
        print(f'Epoch: {epoch+1}, Loss: {epoch_loss}, Eval Loss: {eval_loss}')

        torch.save(model.state_dict(), f'logs/{epoch+1}.pth')

train(num_epochs=50)

Iteration: 595/595, Loss: 0.15864497423171997: 100%|██████████| 595/595 [01:31<00:00,  6.50it/s]  


Evaluating...
Epoch: 1, Loss: 18694.013671875, Eval Loss: 94641.4296875


Iteration: 595/595, Loss: 0.16369278728961945: 100%|██████████| 595/595 [01:28<00:00,  6.71it/s]  


Evaluating...
Epoch: 2, Loss: 150.60545349121094, Eval Loss: 83723.578125


Iteration: 595/595, Loss: 0.11464127153158188: 100%|██████████| 595/595 [01:23<00:00,  7.15it/s]  


Evaluating...
Epoch: 3, Loss: 90.78055572509766, Eval Loss: 48492.13671875


Iteration: 595/595, Loss: 0.3086736798286438: 100%|██████████| 595/595 [01:22<00:00,  7.21it/s]   


Evaluating...
Epoch: 4, Loss: 79.41600036621094, Eval Loss: 33624.2265625


Iteration: 595/595, Loss: 1.4206066131591797: 100%|██████████| 595/595 [01:22<00:00,  7.19it/s]  


Evaluating...
Epoch: 5, Loss: 225.18190002441406, Eval Loss: 19693.099609375


Iteration: 595/595, Loss: 0.02577737532556057: 100%|██████████| 595/595 [01:25<00:00,  6.96it/s] 


Evaluating...
Epoch: 6, Loss: 121.9406509399414, Eval Loss: 17858.20703125


Iteration: 595/595, Loss: 0.40619027614593506: 100%|██████████| 595/595 [01:21<00:00,  7.32it/s] 


Evaluating...
Epoch: 7, Loss: 116.31401062011719, Eval Loss: 11007.7880859375


Iteration: 595/595, Loss: 0.05773354694247246: 100%|██████████| 595/595 [01:22<00:00,  7.24it/s]  


Evaluating...
Epoch: 8, Loss: 365.8838806152344, Eval Loss: 19597.58203125


Iteration: 595/595, Loss: 0.3353491723537445: 100%|██████████| 595/595 [01:22<00:00,  7.19it/s]   


Evaluating...
Epoch: 9, Loss: 73.0303726196289, Eval Loss: 21650.224609375


Iteration: 595/595, Loss: 0.754601001739502: 100%|██████████| 595/595 [01:22<00:00,  7.23it/s]    


Evaluating...
Epoch: 10, Loss: 79.70487976074219, Eval Loss: 40647.42578125


Iteration: 595/595, Loss: 0.020663024857640266: 100%|██████████| 595/595 [01:23<00:00,  7.13it/s]


Evaluating...
Epoch: 11, Loss: 252.41171264648438, Eval Loss: 30.05008316040039


Iteration: 595/595, Loss: 0.08569537103176117: 100%|██████████| 595/595 [01:22<00:00,  7.20it/s]  


Evaluating...
Epoch: 12, Loss: 43.400489807128906, Eval Loss: 40.024662017822266


Iteration: 595/595, Loss: 0.02946864441037178: 100%|██████████| 595/595 [01:21<00:00,  7.31it/s]  


Evaluating...
Epoch: 13, Loss: 29.758384704589844, Eval Loss: 185.08993530273438


Iteration: 595/595, Loss: 0.01033632829785347: 100%|██████████| 595/595 [01:22<00:00,  7.22it/s]  


Evaluating...
Epoch: 14, Loss: 95.96411895751953, Eval Loss: 1.1052576303482056


Iteration: 595/595, Loss: 0.5034350156784058: 100%|██████████| 595/595 [01:22<00:00,  7.21it/s]   


Evaluating...
Epoch: 15, Loss: 52.117210388183594, Eval Loss: 433.30792236328125


Iteration: 595/595, Loss: 0.12651538848876953: 100%|██████████| 595/595 [01:22<00:00,  7.24it/s]  


Evaluating...
Epoch: 16, Loss: 77.17578887939453, Eval Loss: 147.63270568847656


Iteration: 595/595, Loss: 0.005603165365755558: 100%|██████████| 595/595 [01:22<00:00,  7.24it/s] 


Evaluating...
Epoch: 17, Loss: 40.221153259277344, Eval Loss: 32.4567756652832


Iteration: 595/595, Loss: 0.03326454758644104: 100%|██████████| 595/595 [01:21<00:00,  7.31it/s]   


Evaluating...
Epoch: 18, Loss: 38.404666900634766, Eval Loss: 23.114063262939453


Iteration: 595/595, Loss: 0.1687859743833542: 100%|██████████| 595/595 [01:21<00:00,  7.27it/s]   


Evaluating...
Epoch: 19, Loss: 60.693748474121094, Eval Loss: 355.5068359375


Iteration: 595/595, Loss: 0.060936711728572845: 100%|██████████| 595/595 [01:21<00:00,  7.34it/s] 


Evaluating...
Epoch: 20, Loss: 34.00840759277344, Eval Loss: 18.239953994750977


Iteration: 595/595, Loss: 0.04931841045618057: 100%|██████████| 595/595 [01:21<00:00,  7.27it/s]  


Evaluating...
Epoch: 21, Loss: 37.946346282958984, Eval Loss: 15.891458511352539


Iteration: 595/595, Loss: 0.001061872928403318: 100%|██████████| 595/595 [01:21<00:00,  7.35it/s]  


Evaluating...
Epoch: 22, Loss: 357.7290344238281, Eval Loss: 81.07437133789062


Iteration: 595/595, Loss: 0.009917383082211018: 100%|██████████| 595/595 [01:22<00:00,  7.23it/s] 


Evaluating...
Epoch: 23, Loss: 3.606483221054077, Eval Loss: 83.5923080444336


Iteration: 595/595, Loss: 0.02075779065489769: 100%|██████████| 595/595 [01:21<00:00,  7.28it/s]   


Evaluating...
Epoch: 24, Loss: 8.686079025268555, Eval Loss: 46.43870162963867


Iteration: 595/595, Loss: 0.028987077996134758: 100%|██████████| 595/595 [01:22<00:00,  7.23it/s] 


Evaluating...
Epoch: 25, Loss: 30.73595428466797, Eval Loss: 40.61736297607422


Iteration: 595/595, Loss: 0.020345497876405716: 100%|██████████| 595/595 [01:22<00:00,  7.21it/s]  


Evaluating...
Epoch: 26, Loss: 28.346540451049805, Eval Loss: 105.5958023071289


Iteration: 595/595, Loss: 0.011800741776823997: 100%|██████████| 595/595 [01:21<00:00,  7.32it/s]  


Evaluating...
Epoch: 27, Loss: 32.4326171875, Eval Loss: 120.74703979492188


Iteration: 595/595, Loss: 0.0170077383518219: 100%|██████████| 595/595 [01:22<00:00,  7.20it/s]    


Evaluating...
Epoch: 28, Loss: 31.9710693359375, Eval Loss: 128.02923583984375


Iteration: 595/595, Loss: 0.047818806022405624: 100%|██████████| 595/595 [01:21<00:00,  7.26it/s] 


Evaluating...
Epoch: 29, Loss: 31.85393714904785, Eval Loss: 68.69621276855469


Iteration: 595/595, Loss: 0.02238159440457821: 100%|██████████| 595/595 [01:21<00:00,  7.26it/s]   


Evaluating...
Epoch: 30, Loss: 29.497655868530273, Eval Loss: 55.337677001953125


Iteration: 595/595, Loss: 0.0002157724229618907: 100%|██████████| 595/595 [01:21<00:00,  7.31it/s] 


Evaluating...
Epoch: 31, Loss: 0.8462767004966736, Eval Loss: 78.9841537475586


Iteration: 595/595, Loss: 0.0006551884580403566: 100%|██████████| 595/595 [01:21<00:00,  7.29it/s] 


Evaluating...
Epoch: 32, Loss: 1.101682424545288, Eval Loss: 92.3084487915039


Iteration: 595/595, Loss: 0.0010587768629193306: 100%|██████████| 595/595 [01:22<00:00,  7.22it/s] 


Evaluating...
Epoch: 33, Loss: 4.845229625701904, Eval Loss: 99.25418853759766


Iteration: 595/595, Loss: 0.006986238062381744: 100%|██████████| 595/595 [01:21<00:00,  7.28it/s]  


Evaluating...
Epoch: 34, Loss: 8.499953269958496, Eval Loss: 69.56427001953125


Iteration: 595/595, Loss: 0.0031044594943523407: 100%|██████████| 595/595 [01:21<00:00,  7.26it/s] 


Evaluating...
Epoch: 35, Loss: 8.099329948425293, Eval Loss: 83.0823745727539


Iteration: 595/595, Loss: 0.006416121032088995: 100%|██████████| 595/595 [01:22<00:00,  7.18it/s]  


Evaluating...
Epoch: 36, Loss: 7.672605037689209, Eval Loss: 84.1366195678711


Iteration: 595/595, Loss: 0.0009159412584267557: 100%|██████████| 595/595 [01:21<00:00,  7.32it/s] 


Evaluating...
Epoch: 37, Loss: 8.912328720092773, Eval Loss: 83.48954010009766


Iteration: 595/595, Loss: 0.0008936125668697059: 100%|██████████| 595/595 [01:23<00:00,  7.13it/s] 


Evaluating...
Epoch: 38, Loss: 7.337975025177002, Eval Loss: 84.13411712646484


Iteration: 595/595, Loss: 0.034925803542137146: 100%|██████████| 595/595 [01:21<00:00,  7.34it/s]  


Evaluating...
Epoch: 39, Loss: 7.859512805938721, Eval Loss: 176.3374481201172


Iteration: 595/595, Loss: 0.0010940118227154016: 100%|██████████| 595/595 [01:21<00:00,  7.32it/s] 


Evaluating...
Epoch: 40, Loss: 7.994284152984619, Eval Loss: 146.49478149414062


Iteration: 61/595, Loss: 0.013846304267644882:  10%|█         | 61/595 [00:08<01:18,  6.81it/s]  


KeyboardInterrupt: 

In [6]:
import yaml

config_file = yaml.load(open('./experiments/hrnet/hrnet.yaml', 'r'), Loader=yaml.FullLoader)
config_file

{'experiment_name': 'hrnet',
 'model_name': 'hrnet',
 'batch_size': 24,
 'epoch': 210,
 'lr': 0.001,
 'lr_factor': 0.1,
 'lr_step': [170, 200],
 'resume': None}