# PyTorch Snippets
This repository collects useful, frequently reused code snippets for PyTorch.

# Imports

In [None]:
import os
import shutil

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable as var
from torch import FloatTensor as ft
from torch.utils.data import DataLoader
import torch.autograd as autograd

# Dataset Splitting
Split a `Dataset` object into three contiguous sub-datasets: train, validation and test (no shuffling is performed). Adapted from [this notebook](https://github.com/QuantScientist/Deep-Learning-Boot-Camp/blob/master/day02-PyTORCH-and-PyCUDA/PyTorch/21-PyTorch-CIFAR-10-Custom-data-loader-from-scratch.ipynb).

In [None]:
class FullTrainingDataset(torch.utils.data.Dataset):
    """Contiguous view ``full_ds[offset : offset + length]`` of a parent dataset.

    Items are fetched lazily from ``full_ds``; nothing is copied.

    Args:
        full_ds: Parent dataset; must support ``len()`` and integer indexing.
        offset: Index in ``full_ds`` of this view's first item.
        length: Number of items in the view.

    Raises:
        ValueError: If ``offset`` or ``length`` is negative, or ``full_ds``
            has fewer than ``offset + length`` items.
    """

    def __init__(self, full_ds, offset, length):
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if offset < 0 or length < 0:
            raise ValueError("offset and length must be non-negative")
        if len(full_ds) < offset + length:
            raise ValueError("Parent Dataset not long enough")
        super(FullTrainingDataset, self).__init__()
        self.full_ds = full_ds
        self.offset = offset
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return item ``i`` of the view (negative indices count from the end).

        Raises:
            IndexError: If ``i`` falls outside ``[-length, length)``.
        """
        idx = i + self.length if i < 0 else i
        # Reject out-of-view indices. Otherwise i >= length would silently
        # return items belonging to the next split. Also, ``for x in ds``
        # (the __getitem__ iteration fallback, which stops only on
        # IndexError) would run on into the rest of the parent dataset.
        if not 0 <= idx < self.length:
            raise IndexError(
                "index {} out of range for dataset of length {}".format(i, self.length))
        return self.full_ds[idx + self.offset]


def trainTestSplit(dataset, val_share=0.1, test_share=0.15):
    """Split ``dataset`` into contiguous train / validation / test views.

    The split is positional, with no shuffling. The first items form the
    training set, followed by the validation set, then the test set. Split
    boundaries are ``int(len(dataset) * (1 - val_share - test_share))`` and
    ``int(len(dataset) * (1 - test_share))``, truncated toward zero.

    Args:
        dataset: Any dataset supporting ``len()`` and integer indexing.
        val_share: Fraction of items for the validation set.
        test_share: Fraction of items for the test set.

    Returns:
        A ``(train, val, test)`` tuple of ``FullTrainingDataset`` views whose
        lengths sum to ``len(dataset)``.

    Raises:
        ValueError: If a share is negative or the shares sum to more than 1,
            either of which would produce negative or overlapping splits.
    """
    if val_share < 0 or test_share < 0 or val_share + test_share > 1:
        raise ValueError(
            "val_share and test_share must be non-negative and sum to at most 1")
    n = len(dataset)
    val_offset = int(n * (1 - val_share - test_share))
    test_offset = int(n * (1 - test_share))
    return (FullTrainingDataset(dataset, 0, val_offset),
            FullTrainingDataset(dataset, val_offset, test_offset - val_offset),
            FullTrainingDataset(dataset, test_offset, n - test_offset))

# Model Loading/Saving

In [None]:
def save_checkpoint(state, is_best, filename='models/checkpoint.pth.tar',
                    best_filename='models/model_best.pth.tar',
                    summary_filename='models/model_best_summary.txt'):
    """Serialize a training checkpoint and keep a copy of the best one.

    Args:
        state: Mapping passed to ``torch.save``. If it has a ``'summary'``
            entry, that entry is written to ``summary_filename`` when
            ``is_best`` is true.
        is_best: If true, also copy the checkpoint to ``best_filename`` and
            write the summary file.
        filename: Destination path of the checkpoint.
        best_filename: Destination path of the best-checkpoint copy.
        summary_filename: Destination path of the best-model summary text.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
        summary = state.get('summary')
        if summary is not None:
            if isinstance(summary, bytes):
                with open(summary_filename, 'wb') as f:
                    f.write(summary)
            else:
                # Text mode: the summary is normally ``str(model)``, and writing
                # a str to a binary-mode file raises TypeError on Python 3.
                with open(summary_filename, 'w', encoding='utf-8') as f:
                    f.write(str(summary))
    print('==> saved {}'.format('(*)' if is_best else ''))


def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with ``'epoch'``, ``'best_loss'``,
    ``'state_dict'`` and ``'optimizer'`` entries, as written by the
    "make checkpoint" example in this notebook.

    Args:
        model: Module that receives ``checkpoint['state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']``; pass
            ``None`` to skip restoring optimizer state.
        path: Checkpoint path. A falsy value means "start fresh"; the string
            ``'-1'`` selects ``'models/model_best.pth.tar'``.
        map_location: Forwarded to ``torch.load``. For example, use ``'cpu'``
            to load a GPU-saved checkpoint on a CPU-only machine.

    Returns:
        ``(start_epoch, best_loss)``. Returns ``(0, float('inf'))`` when no
        checkpoint is loaded (empty ``path`` or file not found).
    """
    if not path:
        print("==> creating a new model")
        return 0, float("inf")
    if path == '-1':
        print("==> loading best model")
        path = 'models/model_best.pth.tar'
    if not os.path.isfile(path):
        print("==> no checkpoint found at '{}'".format(path))
        # Fall back to a fresh start so callers can always unpack two values.
        return 0, float("inf")
    checkpoint = torch.load(path, map_location=map_location)
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("==> loaded checkpoint '{}' (epoch {})"
          .format(path, checkpoint['epoch']))
    return start_epoch, best_loss

# Make a checkpoint at the end of an epoch. This cell is a usage example: it
# expects `val_loss`, `best_loss`, `epoch`, `model` and `optimizer` to already
# exist in the surrounding training loop.
is_best = val_loss < best_loss
# Keep the running minimum; ties leave best_loss untouched.
best_loss = val_loss if is_best else best_loss
save_checkpoint(
    dict(
        epoch=epoch + 1,  # store the *next* epoch so a resume starts there
        state_dict=model.state_dict(),
        best_loss=best_loss,
        optimizer=optimizer.state_dict(),
        summary=str(model),
    ),
    is_best,
)