Unable to get eos_id properly after converting to torchscript #71
Closed
Description
I converted the parseq.pt model to TorchScript (parseq.torchscript). After basic testing, I found that on some images the parseq.torchscript model does not return the correct eos_idx in C++.
In Python, using the parseq.pt model:
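```python
from PIL import Image

# Load image and prepare for input
image = Image.open(fname).convert('RGB')
image = img_transform(image).unsqueeze(0).to(args.device)
p = model(image).softmax(-1)
# Raw per-position probability and argmax id, before EOS filtering
probs, ids = p[0].max(-1)
pred, p = model.tokenizer.decode(p)
print(ids)
print(probs)
print(f'{fname}: {pred[0]}')
```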
Output:
```
tensor([ 6, 76,  1,  1,  9,  3,  4,  6,  7, 10,  2,  1,  1,  4,  0,  4,  3,  2,
         4,  4,  7, 10,  1,  4,  4,  0])
tensor([0.9991, 0.8549, 0.9999, 0.9998, 0.9997, 0.9998, 0.9993, 0.9999, 0.9997,
        0.9999, 0.9996, 0.9997, 0.9999, 0.9988, 0.9986, 0.3257, 0.2137, 0.2015,
        0.3253, 0.4497, 0.2691, 0.7619, 0.2822, 0.4551, 0.7156, 0.8741])
cropped.jpeg: 5.008235691003
```
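The greedy decoding that both the Python and C++ code perform can be sketched in plain Python, without torch: take the argmax id at each position, stop at the first EOS id, and map the remaining ids to characters. This is a minimal sketch under stated assumptions: EOS is id 0 and character ids start at 1 (as in the C++ loop in this issue), and `CHARSET` is an assumed 94-character set (digits, lowercase, uppercase, punctuation) that you should replace with the one your checkpoint was trained on. `greedy_decode` and `CHARSET` are illustrative names, not parseq API.

```python
import string

# Assumption: 94-character charset. Replace with your model's training charset.
CHARSET = (string.digits + string.ascii_lowercase
           + string.ascii_uppercase + string.punctuation)
EOS_ID = 0  # assumption: EOS is id 0 and character ids start at 1


def greedy_decode(ids, charset=CHARSET, eos_id=EOS_ID):
    """Map argmax ids to text, stopping at the first EOS id."""
    chars = []
    for i in ids:
        if i == eos_id:
            break
        chars.append(charset[i - 1])  # id 1 -> charset[0]
    return "".join(chars)


# Argmax ids copied from the Python and C++ outputs in this issue
python_ids = [6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 2, 1, 1, 4, 0,
              4, 3, 2, 4, 4, 7, 10, 1, 4, 4, 0]
cpp_ids = [6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 1, 1, 1, 4, 4,
           4, 4, 4, 4, 4, 4, 10, 1, 4, 4, 0]

print(greedy_decode(python_ids))  # 5.008235691003
print(greedy_decode(cpp_ids))     # 5.00823569000333333339033
```

Under this assumed charset, id 76 decodes to `.`, which matches the Python result; the C++ output shows `-` at that position, so it may be worth checking that `char_set` in the C++ code matches the training charset.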
I converted parseq.pt to parseq.torchscript as follows:
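```python
import torch
# trace records only the ops executed for this input; data-dependent control flow is not captured
dummy_input = torch.rand(1, 3, 32, 128)  # (1, 3, 32, 128) by default
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("parseq.torchscript")
```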
In C++, using the parseq.torchscript model:
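```cpp
// ......
// ......
// Scale pixel values to [0, 1]
tensor_image = tensor_image.toType(c10::kFloat).div(255);
// Reorder HWC -> CHW
tensor_image = transpose(tensor_image, {(2), (0), (1)});
// Add the batch dimension: (1, 3, H, W)
tensor_image.unsqueeze_(0);
std::cout << "input shape : " << tensor_image.sizes() << std::endl;

std::vector<torch::jit::IValue> inputs;
inputs.push_back(tensor_image);
at::Tensor output = module.forward(inputs).toTensor();

// softmax() is not in-place, so its result must be assigned. Without the
// assignment, the "probs" values in the output that follows are raw logits.
output = output.softmax(-1);
// max(-1) returns (values, indices): per-position probability and argmax id
std::tuple<at::Tensor, at::Tensor> probs_ids = output.max(-1);
std::cout << "ids : " << std::get<1>(probs_ids) << std::endl;
std::cout << "probs : " << std::get<0>(probs_ids) << std::endl;

std::string word;
at::Tensor ids = std::get<1>(probs_ids);
for (int64_t c = 0; c < ids.size(1); c++)
{
    int64_t id = ids[0][c].item<int64_t>();
    if (id == 0)  // id 0 is EOS: stop decoding
    {
        break;
    }
    word += char_set[id - 1];  // character ids start at 1
}
std::cout << word << std::endl;
```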
Output:
```
ids :  Columns 1 to 20
  6  76   1   1   9   3   4   6   7  10   1   1   1   4   4   4   4   4   4   4

 Columns 21 to 26
  4  10   1   4   4   0
[ CPULongType{1,26} ]
probs :  Columns 1 to 8
 14.4657  11.2162  17.1043  17.0559  15.4707  15.9846  15.6380  15.6537

 Columns 9 to 16
 15.2115  13.4674  11.0046  13.6310  12.1103  13.6405   8.6431   6.6635

 Columns 17 to 24
  6.6002   5.9998   8.9770  10.0183   9.4945  11.0694   8.3291   8.9739

 Columns 25 to 26
  9.1511   8.8696
[ CPUFloatType{1,26} ]
5-00823569000333333339033
```
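To see where the traced model starts to disagree with the eager one, the two printed `ids` sequences can be compared position by position. The helper below is a hypothetical plain-Python sketch (not part of parseq), assuming EOS is id 0 as in the C++ loop; the id lists are copied from the two outputs in this issue.

```python
EOS_ID = 0  # assumption: EOS is id 0, as in the C++ loop in this issue

python_ids = [6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 2, 1, 1, 4, 0,
              4, 3, 2, 4, 4, 7, 10, 1, 4, 4, 0]
cpp_ids = [6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 1, 1, 1, 4, 4,
           4, 4, 4, 4, 4, 4, 10, 1, 4, 4, 0]


def first_divergence(a, b, eos_id=EOS_ID):
    """Return the first index where two id sequences differ, or None.

    Stops early if both sequences emit EOS at the same index, since the
    decoded text is then identical.
    """
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
        if x == eos_id:
            return None
    # No mismatch within the common prefix: differ only if lengths differ
    return None if len(a) == len(b) else min(len(a), len(b))


print("first mismatch at index:", first_divergence(python_ids, cpp_ids))  # 10
print("EOS index (Python):", python_ids.index(EOS_ID))  # 14
print("EOS index (C++):", cpp_ids.index(EOS_ID))        # 25
```

On the values printed in this issue, the first mismatch is at index 10 (column 11: id 2 in Python vs id 1 in C++), before the Python model emits EOS at index 14. So the traced model's per-position predictions already differ before the end-of-sequence position, and the missing EOS is one symptom of that.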