Skip to content

Commit

Permalink
Merge pull request #105 from rainyl/main
Browse files Browse the repository at this point in the history
🔨 fix the image resize in pix2tex.call_model
  • Loading branch information
lukas-blecher committed Mar 9, 2022
2 parents b5217d2 + 9f510b5 commit fd271d2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import torch.nn as nn
import torch.nn.functional as F

from x_transformers import *
from x_transformers.autoregressive_wrapper import *
# from x_transformers import *
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.models.resnetv2 import ResNetV2
Expand Down
6 changes: 4 additions & 2 deletions pix2tex.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def initialize(arguments=None):
args = parse_args(Munch(params))
args.update(**vars(arguments))
args.wandb = False
# args.device = "cpu"
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'

model = get_model(args)
Expand Down Expand Up @@ -82,9 +83,10 @@ def call_model(args, model, image_resizer, tokenizer, img=None):
if image_resizer is not None and not args.no_resize:
with torch.no_grad():
input_image = img.convert('RGB').copy()
r, w = 1, input_image.size[0]
r, w, h = 1, input_image.size[0], input_image.size[1]
for _ in range(10):
img = pad(minmax_size(input_image.resize((w, int(input_image.size[1]*r)), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
h = int(h * r) # height to resize
img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
w = (image_resizer(t.to(args.device)).argmax(-1).item()+1)*32
logging.info(r, img.size, (w, int(input_image.size[1]*r)))
Expand Down

0 comments on commit fd271d2

Please sign in to comment.