load torch files without torch
geohot committed Nov 21, 2020
1 parent 2689986 commit 03994e0
Showing 1 changed file with 73 additions and 4 deletions.
77 changes: 73 additions & 4 deletions extra/efficientnet.py
@@ -4,6 +4,71 @@
from tinygrad.utils import fetch
from tinygrad.nn import BatchNorm2D

USE_TORCH = False  # set True to load checkpoints through the real torch.load instead of fake_torch_load below

def fake_torch_load(b0):
  import io
  import pickle
  import struct

  # convert it to a file
  fb0 = io.BytesIO(b0)

  # skip three junk pickles (the legacy torch save format stores a magic
  # number, a protocol version, and a system-info dict before the real data)
  pickle.load(fb0)
  pickle.load(fb0)
  pickle.load(fb0)

  key_prelookup = {}

  class HackTensor:
    def __new__(cls, *args):
      #print(args)
      # the pickle calls torch._utils._rebuild_tensor_v2(storage, offset, size, stride, ...),
      # remapped to this class by find_class below, so args[0] is the storage's
      # persistent-id tuple and args[2]/args[3] are the tensor's shape and strides
      ident, storage_type, obj_key, location, obj_size, view_metadata = args[0]
      assert ident == 'storage'

      ret = np.zeros(obj_size, dtype=storage_type)
      key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
      return ret

  class MyPickle(pickle.Unpickler):
    def find_class(self, module, name):
      #print(module, name)
      if name == 'FloatStorage':
        return np.float32
      if name == 'LongStorage':
        return np.int64
      if module == "torch._utils" or module == "torch":
        return HackTensor
      else:
        return pickle.Unpickler.find_class(self, module, name)

    def persistent_load(self, pid):
      # storages are referenced by persistent id; return the raw tuple so
      # HackTensor above can unpack it
      return pid
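
  # after the main pickle, the legacy torch format stores a pickled list of
  # storage keys, then each storage's raw bytes prefixed by an 8-byte element
  # count; the loops below walk exactly that layout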

  ret = MyPickle(fb0).load()

  # create key_lookup
  key_lookup = pickle.load(fb0)
  key_real = [None] * len(key_lookup)
  for k,v in key_prelookup.items():
    key_real[key_lookup.index(k)] = v

  # read in the actual data
  for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
    ll = struct.unpack("Q", fb0.read(8))[0]
    assert ll == obj_size
    bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
    mydat = fb0.read(ll * bytes_size)
    np_array[:] = np.frombuffer(mydat, storage_type)
    np_array.shape = np_shape

    # numpy stores its strides in bytes
    real_strides = tuple([x*bytes_size for x in np_strides])
    np_array.strides = real_strides

  return ret
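
# hypothetical helper, for illustration only: fake_torch_load returns a
# dict-like state_dict whose values are numpy arrays, so a checkpoint can be
# inspected without torch installed (same b0 URL as in load_weights_from_torch)
def _inspect_b0_checkpoint():
  dat = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
  state = fake_torch_load(dat)
  for k, v in list(state.items())[:3]:
    print(k, v.shape, v.dtype)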

class MBConvBlock:
  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
    oup = expand_ratio * input_filters
@@ -123,8 +188,6 @@ def forward(self, x):

  def load_weights_from_torch(self, gpu):
    # load b0
    import io
    import torch
    # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
    if self.number == 0:
      b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
@@ -136,7 +199,13 @@ def load_weights_from_torch(self, gpu):
      b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
    else:
      raise Exception("no pretrained weights")
    b0 = torch.load(io.BytesIO(b0))

    if USE_TORCH:
      import io
      import torch
      b0 = torch.load(io.BytesIO(b0))
    else:
      b0 = fake_torch_load(b0)
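    # in the fake_torch_load path, b0 already maps parameter names to numpy
    # arrays, so the values are used below without calling .numpy()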

    for k,v in b0.items():
      if '_blocks.' in k:
@@ -150,7 +219,7 @@ def load_weights_from_torch(self, gpu):
        mv = eval(mk.replace(".weight", ""))
      except AttributeError:
        mv = eval(mk.replace(".bias", "_bias"))
      vnp = v.numpy().astype(np.float32)
      vnp = v.numpy().astype(np.float32) if USE_TORCH else v
      mv.data[:] = vnp if k != '_fc.weight' else vnp.T
      if gpu:
        mv.cuda_()
