In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
from fastai.vision.data import imagenet_stats
from fastai.metrics import accuracy
from torchvision.models import vgg16_bn
from tqdm import tqdm, tqdm_notebook
from crystal_clear.tensor_pipeline import TensorImageList

In [None]:
# Which preprocessing pipeline's data to train on: 'tensor_pipeline' or 'image_pipeline'.
pipeline = 'tensor_pipeline'
path = Path(f'./data/crappified/dataset_1/{pipeline}')

# Each pipeline keeps its original (`orig_*`) inputs in its own sub-directory.
hr_subdirs = {'tensor_pipeline': 'orig_tensor', 'image_pipeline': 'orig_spectr'}
if pipeline not in hr_subdirs:
    # Fail here rather than with a NameError on `path_hr` several cells later.
    raise ValueError(f'Unknown pipeline {pipeline!r}; expected one of {sorted(hr_subdirs)}')
path_hr = path / hr_subdirs[pipeline]

In [None]:
# Per-pipeline metadata CSV, stored in the `meta/` folder next to the pipeline directories
# (i.e. under the same dataset root as `path`).
meta_files = {'tensor_pipeline': 'meta_tensor.csv', 'image_pipeline': 'meta_spectr.csv'}
if pipeline not in meta_files:
    raise ValueError(f'Unknown pipeline {pipeline!r}; expected one of {sorted(meta_files)}')
meta = pd.read_csv(path.parent / 'meta' / meta_files[pipeline])

# split_from_df('subset') in the data cell treats this column as an "is validation"
# flag, so turn the 'valid' label into True and everything else into False.
# The CSV is re-read above, so re-running this cell is safe.
meta['subset'] = meta['subset'] == 'valid'

In [None]:
# Peek at the first rows of the metadata (the `subset` column is now boolean).
meta.head(n=5)

In [None]:
# VGG16 with batch norm, initialised with torchvision's ImageNet-pretrained weights
# (keyword form instead of the opaque positional `True`).
vgg_orig = vgg16_bn(pretrained=True)

In [None]:
# Keep every layer of the pretrained classifier except the last one, and replace that
# layer with a Linear layer sized for our genre labels.
# in_features comes from the layer being replaced instead of a hard-coded 4096, and
# the output size comes from the data instead of a hard-coded 8.
# NOTE(review): assumes every genre in `meta` also appears in the training split,
# since fastai builds its class list from the training labels — TODO confirm.
n_classes = meta['genre'].nunique()
old_classifier_layers = list(vgg_orig.classifier.children())
new_head = torch.nn.Sequential(
    *old_classifier_layers[:-1],
    torch.nn.Linear(
        in_features=old_classifier_layers[-1].in_features,
        out_features=n_classes,
        bias=True,
    ),
)
new_head

In [None]:
# Install the genre head in place of VGG's original classifier
# (goes through nn.Module.__setattr__, so the submodule is registered).
setattr(vgg_orig, 'classifier', new_head)

In [None]:
bs = 32

# The two pipelines differ only in the item-list class and the normalization stats;
# the rest of the data-block chain is shared.
if pipeline == 'tensor_pipeline':
    item_list_cls = TensorImageList
    # Stats saved next to the tensor data; presumably (mean, std) — TODO confirm format.
    data_stats = torch.load(path / 'data_stats.pkl')
elif pipeline == 'image_pipeline':
    item_list_cls = ImageList
    data_stats = imagenet_stats
else:
    raise ValueError(f'Unknown pipeline {pipeline!r}')

data = (
    item_list_cls.from_df(meta, path_hr)
    .split_from_df('subset')   # boolean "is validation" column built from the meta CSV
    .label_from_df('genre')
    .databunch(bs=bs)
    .normalize(data_stats)
)

In [None]:
# Wrap the modified VGG in a plain fastai Learner and report classification accuracy.
learn = Learner(data=data, model=vgg_orig, metrics=[accuracy])

In [None]:
# Run fastai's learning-rate range test, then plot loss against learning rate
# to help choose max_lr below.
learn.lr_find()
recorder_plot = learn.recorder.plot
recorder_plot();

In [None]:
# Peak learning rate for the one-cycle schedule (presumably read off the LR-finder plot).
max_lr = 0.001

In [None]:
# Train for a single epoch with the one-cycle learning-rate policy.
learn.fit_one_cycle(cyc_len=1, max_lr=max_lr)

In [None]:
save_name = 'model_1_fastai'
model_ft = learn.model
model_dir = path / 'model_clas'
# torch.save does not create missing parent directories, so make sure the target exists.
model_dir.mkdir(parents=True, exist_ok=True)

# Whole pickled module: reloading requires the model's class definitions to be importable.
torch.save(model_ft, model_dir / f'{save_name}.pth')
# Weights only (state_dict): the more portable form; load into a freshly built model.
torch.save(model_ft.state_dict(), model_dir / f'{save_name}_weights.pth')