# coding: utf8
from __future__ import unicode_literals
import contextlib
from ..compat import BytesIO
from ..neural._classes.model import Model
try:
    import cupy
except ImportError:
    cupy = None

try:
    import torch.autograd
    import torch.optim
    import torch
    import torch.utils.dlpack
    from torch.nn import Module as PyTorchModule
except ImportError:
    torch = None
    PyTorchModule = None


def xp2torch(xp_tensor):
    """Convert a numpy or cupy array into a torch tensor."""
    if hasattr(xp_tensor, "toDlpack"):
        return torch.utils.dlpack.from_dlpack(xp_tensor.toDlpack())
    else:
        return torch.from_numpy(xp_tensor)


def torch2xp(torch_tensor):
    """Convert a torch tensor into a numpy or cupy array."""
    if torch_tensor.is_cuda:
        return cupy.fromDlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
    else:
        return torch_tensor.detach().numpy()
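# Note on the two helpers above: on GPU, cupy arrays are exchanged with torch
# zero-copy via DLPack; on CPU, ``torch.from_numpy`` and ``.detach().numpy()``
# share memory with the underlying numpy buffer. A minimal round-trip sketch
# (assuming numpy and torch are installed):
#
#     import numpy
#     x = numpy.zeros((2, 3), dtype="f")
#     t = xp2torch(x)    # float32 torch tensor backed by x's memory
#     x2 = torch2xp(t)   # back to a numpy array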


class PyTorchWrapper(Model):
    """Wrap a PyTorch model, so that it has the same API as Thinc models.
    To optimize the model, you'll need to create a PyTorch optimizer and call
    optimizer.step() after each batch --- see examples/wrap_pytorch.py
    """

    def __init__(self, model, conf=None):
        Model.__init__(self)
        self._model = model
        self.conf = conf
        self._optimizer = None
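
    # Usage sketch (a hypothetical snippet following the pattern described in
    # the class docstring; see examples/wrap_pytorch.py for the full example):
    #
    #     model = PyTorchWrapper(torch.nn.Linear(32, 8))
    #     Y, backprop = model.begin_update(X)       # X: numpy/cupy array
    #     dX = backprop(dY, sgd=thinc_optimizer)    # also steps the weights
    #
    # When ``sgd`` is passed to the backward callback, a torch optimizer is
    # created lazily from the Thinc optimizer's hyperparameters and stepped
    # after each backward pass.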

    def begin_update(self, x_data, drop=0.0, **kwargs):
        """Return the output of the wrapped PyTorch model for the given input,
        along with a callback to handle the backward pass.
        """
        if self.conf is None:
            x_var = torch.autograd.Variable(xp2torch(x_data), requires_grad=True)
            y_var = self._model(x_var, **kwargs)

            def backward_pytorch(dy_data, sgd=None):
                dy_var = xp2torch(dy_data)
                torch.autograd.backward((y_var,), grad_tensors=(dy_var,))
                if sgd is not None:
                    if self._optimizer is None:
                        self._optimizer = self._create_optimizer(sgd)
                    self._optimizer.step()
                    self._optimizer.zero_grad()
                return torch2xp(x_var.grad)

            y = torch2xp(y_var)
        else:
            """
            self.conf specifies how the PyTorch wrapper should behave when
            multiple inputs come into play. Say the wrapped model takes n
            inputs and produces m outputs.
            Args:
                i_grad (iterable of bool):
                    controls which of the n inputs require a gradient
                o_xp (int or None):
                    the number of outputs, or None for a single output
                b_map (list or None):
                    maps each dYi to one or more Yj;
                    leave as None if there is only one gradient
                ret_x (list of int):
                    indexes of the inputs whose values or gradients are
                    returned by the backward pass
            """
            i_grad, o_xp, b_map, ret_x = self.conf
            # Convert the input arrays to tensors, with or without grad
            x_var = [
                xp2torch(x_data[i])
                if not grad
                else torch.autograd.Variable(xp2torch(x_data[i]), requires_grad=True)
                for i, grad in enumerate(i_grad)
            ]
            y_var = self._model(x_var, **kwargs)
            # Convert the output tensors back to arrays
            if o_xp is not None:
                y = []
                for i in range(o_xp):
                    y.append(torch2xp(y_var[i]))
                y = tuple(y)
            else:
                y = torch2xp(y_var)

            def backward_pytorch(dy_data, sgd=None):
                if b_map is None:
                    # one gradient, one output
                    dy_var = xp2torch(dy_data)
                    torch.autograd.backward((y_var,), grad_tensors=(dy_var,))
                else:
                    if len(b_map) == 1:
                        # one gradient, multiple outputs
                        dy_var = xp2torch(dy_data)
                        grad_list = b_map[0]
                        for y_indx in grad_list:
                            torch.autograd.backward(
                                (y_var[y_indx],),
                                grad_tensors=(dy_var,),
                                retain_graph=True,
                            )
                    else:
                        # multiple gradients
                        for grad_indx, grad_list in enumerate(b_map):
                            dy_var = xp2torch(dy_data[grad_indx])
                            for y_indx in grad_list:
                                if o_xp is not None:
                                    torch.autograd.backward(
                                        (y_var[y_indx],), grad_tensors=(dy_var,)
                                    )
                                else:
                                    torch.autograd.backward(
                                        (y_var,),
                                        grad_tensors=(dy_var,),
                                        retain_graph=True,
                                    )
                if sgd is not None:
                    if self._optimizer is None:
                        self._optimizer = self._create_optimizer(sgd)
                    self._optimizer.step()
                    self._optimizer.zero_grad()
                b_out = []
                for indx in ret_x:
                    if i_grad[indx]:
                        b_out.append(torch2xp(x_var[indx].grad))
                    else:
                        b_out.append(torch2xp(x_var[indx]))
                if len(b_out) == 0:
                    return None
                elif len(b_out) == 1:
                    return b_out[0]
                else:
                    return b_out

            return y, backward_pytorch
        return torch2xp(y_var), backward_pytorch

    def _create_optimizer(self, sgd):
        """Create a torch optimizer mirroring the Thinc optimizer's settings."""
        params = self._model.parameters()
        if sgd.b1 != 0 and sgd.b2 != 0:
            optimizer = torch.optim.Adam(params, lr=sgd.alpha, betas=(sgd.b1, sgd.b2))
        elif sgd.b2 == 0:
            optimizer = torch.optim.SGD(params, lr=sgd.alpha, momentum=sgd.b1)
        else:
            raise NotImplementedError
        return optimizer
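
    # For example (hypothetical values): a Thinc optimizer with alpha=0.001,
    # b1=0.9 and b2=0.999 yields torch.optim.Adam(lr=0.001, betas=(0.9, 0.999)),
    # while b1=0.9 and b2=0.0 yields torch.optim.SGD(lr=0.001, momentum=0.9).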

    def to_disk(self, path):
        # TODO: Untested
        torch.save(self._model.state_dict(), str(path))

    def from_disk(self, path):
        # TODO: Untested
        self._model.load_state_dict(torch.load(path))

    def to_bytes(self):
        # TODO: Untested
        filelike = BytesIO()
        torch.save(self._model.state_dict(), filelike)
        filelike.seek(0)
        return filelike.getvalue()

    def from_bytes(self, data):
        # TODO: Untested
        filelike = BytesIO(data)
        filelike.seek(0)
        self._model.load_state_dict(torch.load(filelike))

    def to_gpu(self, device_num):
        self._model.cuda(device_num)

    def to_cpu(self):
        self._model.cpu()

    def resize_output(self, new_dim):
        # self.weight = nn.Parameter(F.pad(self.weight, ...)) # add classes
        # self.weight = nn.Parameter(F.pad(model.weight, ...)) # add classes
        raise NotImplementedError

    def resize_input(self):
        raise NotImplementedError

    @contextlib.contextmanager
    def use_params(self, params):  # pragma: no cover
        if self.id in params:
            backup = self.to_bytes()
            self.from_bytes(params[self.id])
        else:
            backup = None
        yield
        if backup is not None:
            self.from_bytes(backup)


class PyTorchWrapperRNN(PyTorchWrapper):
    """Wrap a PyTorch RNN model."""

    def __call__(self, x_data, h_0=None):
        x_var = torch.autograd.Variable(xp2torch(x_data), requires_grad=False)
        # Make prediction
        out, h_n = self._model(x_var, h_0)
        return (self.ops.asarray(out.data), h_n)

    def begin_update(self, x_data, h_0=None, drop=0.0):
        """Return the output of the wrapped PyTorch model for the given input,
        along with a callback to handle the backward pass.
        """
        x_var = torch.autograd.Variable(xp2torch(x_data), requires_grad=True)
        # Make prediction
        out, h_n = self._model(x_var, h_0)
        # Shapes will be:
        # out = seq_len, batch, hidden_size * num_directions
        # h_n = num_layers * num_directions, batch, hidden_size

        def backward_pytorch_rnn(d_data, sgd=None):
            dy_data, _ = d_data
            dout = xp2torch(dy_data)
            torch.autograd.backward((out,), grad_tensors=(dout,))
            if sgd is not None:
                if self._optimizer is None:
                    self._optimizer = self._create_optimizer(sgd)
                self._optimizer.step()
                self._optimizer.zero_grad()
            return torch2xp(x_var.grad)

        return (torch2xp(out), h_n), backward_pytorch_rnn
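
    # Usage sketch (hypothetical; shapes follow the comments above):
    #
    #     rnn = PyTorchWrapperRNN(torch.nn.GRU(input_size=16, hidden_size=32))
    #     (out, h_n), backprop = rnn.begin_update(X)   # X: (seq_len, batch, 16)
    #     dX = backprop((d_out, None), sgd=thinc_optimizer)
    #
    # The backward callback expects a (d_out, d_hidden) pair and only
    # backpropagates through ``out``; the hidden-state gradient is ignored.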

    def resize_output(self, new_dim):
        # self.weight = nn.Parameter(F.pad(self.weight, ...)) # add classes
        # self.weight = nn.Parameter(F.pad(model.weight, ...)) # add classes
        raise NotImplementedError

    def resize_input(self):
        raise NotImplementedError

    @contextlib.contextmanager
    def use_params(self, params):  # pragma: no cover
        if self.id in params:
            backup = self.to_bytes()
            self.from_bytes(params[self.id])
        else:
            backup = None
        yield
        if backup is not None:
            self.from_bytes(backup)
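

# A minimal end-to-end sketch (not part of the original module, assuming numpy
# and torch are installed): wrap a small torch model, run a forward pass, and
# pull the input gradient back out through the backward callback. Pass a Thinc
# optimizer as ``sgd`` to also step the wrapped model's weights.
if __name__ == "__main__":
    import numpy

    model = PyTorchWrapper(torch.nn.Linear(4, 2))
    X = numpy.zeros((8, 4), dtype="f")
    Y, backprop = model.begin_update(X)
    dY = numpy.ones(Y.shape, dtype="f")
    dX = backprop(dY)  # no ``sgd`` given, so no optimizer step
    print(Y.shape, dX.shape)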