In [1]:
import io
import json
import os
import sys

import dvc.api
import matplotlib.pyplot as plt
import numpy as np
import torch as t, torch.nn as nn, torch.nn.functional as F, torch.distributions as tdist
import torchvision as tv, torchvision.transforms as tr
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import random_split
from torchvision import datasets,transforms

In [2]:
class NNThree(nn.Module):
    """VGG-style CNN mapping a batch of 3x32x32 images to 10 raw output scores.

    Architecture: three convolutional stages, each
    ``conv3x3 -> act -> conv3x3 -> act -> 2x2 max-pool -> batch-norm``
    (Softplus activations in the first stage, ReLU in the other two), followed
    by a three-layer fully connected head.

    Every layer lives in a single ``nn.Sequential`` stored as ``self.network``,
    so checkpoint keys look like ``network.<index>.<param>``. The layer order
    must not change, or previously saved state dicts will no longer load.
    """

    def __init__(self):
        super().__init__()

        # One row per stage: (in_channels, mid_channels, out_channels, activation).
        # Each stage keeps H/W through the padded convs, then halves them in the pool.
        stage_specs = (
            (3, 32, 64, nn.Softplus),   # [N, 3, 32, 32]   -> [N, 64, 16, 16]
            (64, 128, 128, nn.ReLU),    # [N, 64, 16, 16]  -> [N, 128, 8, 8]
            (128, 256, 256, nn.ReLU),   # [N, 128, 8, 8]   -> [N, 256, 4, 4]
        )

        layers = []
        for c_in, c_mid, c_out, activation in stage_specs:
            layers += [
                nn.Conv2d(c_in, c_mid, kernel_size=3, stride=1, padding=1),
                activation(),
                nn.Conv2d(c_mid, c_out, kernel_size=3, stride=1, padding=1),
                activation(),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(c_out),
            ]

        # Classifier head: [N, 256, 4, 4] -> [N, 4096] -> ... -> [N, 10].
        layers += [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # raw scores; no softmax is applied here
        ]

        self.network = nn.Sequential(*layers)

    def forward(self, xb):
        """Run the network on ``xb`` (expected shape [N, 3, 32, 32]) and return [N, 10] scores."""
        scores = self.network(xb)
        return scores

In [3]:
# Checkpoint file tracked with DVC in the GitHub repo queried below.
PATH = './manualwideresnetCIFAR10.pth'

# Ask DVC for the storage location of the tracked file; the resolved
# remote URL is displayed as this cell's output.
resource_url = dvc.api.get_url(
    PATH,
    repo='https://github.com/mjcurran/MLExamples',
)
resource_url

'gdrive://1KnAUlHs375IGNtM1sOePm8iR6rQbKrhB/7e/51a4a209a16d3a7492a9a89b1a1b0a'

In [10]:
model = NNThree()

# torch.load needs a seekable "file", and the handle dvc.api.open returns for a
# remotely stored file may not be seekable, so read the checkpoint into memory first.
# PATH is the DVC-tracked checkpoint path set in the earlier get_url cell.
with dvc.api.open(PATH, repo='https://github.com/mjcurran/MLExamples', mode='rb') as f:
    buffer = io.BytesIO(f.read())

# map_location='cpu' remaps any CUDA tensors in the checkpoint onto the CPU, so a
# checkpoint saved from a GPU run still loads on a machine without CUDA;
# load_state_dict then copies the tensors into the (CPU) model's parameters.
# NOTE(review): torch.load unpickles the file. Since this is a plain state_dict,
# consider weights_only=True (torch >= 1.13) if the remote repo is not fully trusted.
model.load_state_dict(t.load(buffer, map_location='cpu'))

<All keys matched successfully>