-
Notifications
You must be signed in to change notification settings - Fork 1
/
prediction.py
30 lines (24 loc) · 1.02 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from fastai.core import *
from fastai.vision import *
import fastai.metrics as metrics
from torch import sum
from torch.nn import CrossEntropyLoss as CRE
from torch.nn.functional import relu
from .loss_functions import *
def get_pred(out):
    """Decode a prediction from a 5-element model output.

    Only the second element (named ``concat_logits``) is used; the other four
    are ignored. The strict 5-way unpack means any other length raises
    ``ValueError``.

    Args:
        out: a sequence of exactly five model outputs.

    Returns:
        The index of the largest value in ``concat_logits``, from
        ``.argmax()`` with no ``dim`` (i.e. over the flattened values).
    """
    (_first, logits, _third, _fourth, _fifth) = out
    return logits.argmax()
def predict(learn:Learner, item:ItemBase):
    """Run `learn` on a single `item` and decode the result with `get_pred`.

    Appears adapted from fastai v1's `Learner.predict`. Instead of
    `ds.y.analyze_pred`, it decodes the raw model output with the local
    `get_pred` helper.

    Args:
        learn: the fastai `Learner` whose model and data bunch are used.
        item: the item to predict on; batched via `learn.data.one_item`.

    Returns:
        Tuple `(x, y, pred, raw_pred)`:
        - `x`: the (denormalized, if applicable) input, reconstructed by `ds.x`;
        - `y`: `pred` reconstructed by `ds.y` (passed `x` if it accepts it);
        - `pred`: the decoded prediction returned by `get_pred`;
        - `raw_pred`: the raw model output for this item.
    """
    batch = learn.data.one_item(item)
    res = learn.pred_batch(batch=batch)
    # Take element 0 of the batch; `one_item` presumably yields a batch of one.
    raw_pred, x = grab_idx(res, 0), batch[0]
    norm = getattr(learn.data, 'norm', False)
    if norm:
        x = learn.data.denorm(x)
        # The original denormalized `pred` here, before it was assigned
        # (UnboundLocalError). Like fastai's Learner.predict, denormalize the
        # raw output instead.
        # NOTE(review): if the model returns several outputs, `raw_pred` is a
        # list and `denorm` may not accept it; confirm do_y is never set here.
        if norm.keywords.get('do_y', False):
            raw_pred = learn.data.denorm(raw_pred)
    ds = learn.data.single_ds
    pred = get_pred(raw_pred)
    x = ds.x.reconstruct(grab_idx(x, 0))
    # Some label classes need the input to reconstruct the target.
    if has_arg(ds.y.reconstruct, 'x'):
        y = ds.y.reconstruct(pred, x)
    else:
        y = ds.y.reconstruct(pred)
    return (x, y, pred, raw_pred)