In [None]:
from fastai.vision.all import *
# from utils import *

In [None]:
# Folder that is prepended to the image column when the DataBlock reads its inputs (`pref=path`).
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = Path('/root/Documents/images')

In [None]:
# Sanity check: list what is inside the image folder.
path.ls()

In [None]:
# Base folder for the CSV files loaded in the following cells (they live under `csv/`).
df_path = Path('/root/Documents/')
# Alternative location / dataset kept from earlier experiments:
# df_path = Path('/workspace/Notebooks/')
# df = pd.read_csv(df_path/"species1000-df.csv")

In [None]:
# Main table; later cells read its `Specie` and `Family` columns and the image column at index 5.
df = pd.read_csv(df_path/'csv/families_3.csv')

In [None]:
# Per-class loss weights: one table keyed by `Specie`, one keyed by `Family` (both expose a `Weight` column,
# see the cell that builds the `weights` tensors).
weights_df = pd.read_csv(df_path/'csv/families3-weights.csv')
weights_family_df = pd.read_csv(df_path/'csv/families3-weights-family.csv')

In [None]:
# Display the species weight table for a visual check.
weights_df

# HD-CNN

In [None]:
# Number of distinct values in the `Specie` column (the fine-grained classes).
n_classes = len(set(df['Specie']))

In [None]:
# Exploratory: build the ResNet-50 body with `cut=-2` and display it to see how it can be sliced.
# NOTE: `model` is rebound to the custom network in a later cell.
model = resnet50
body = create_body(model, cut=-2)
body

In [None]:
# Exploratory: the first four children of the body (presumably the ResNet stem — confirm in the output).
stem = body[0:4]
stem

In [None]:
# Exploratory: children 4..7 of the body, taken one by one; the custom model below splits the body at index 6.
block1, block2, block3, block4 = body[4], body[5], body[6], body[7]
block1

In [None]:
# Exploratory: a stand-alone classification head with `n_classes` outputs, displayed for inspection.
# NOTE(review): `2048*2` presumably accounts for the head concatenating average and max pooling of the
# 2048 body channels; whether the caller must double the value depends on the fastai version — confirm.
head = create_head(2048*2,n_classes)
head

In [None]:
def custom_get_y(o):
    "Target for one row `o`: the pair `[Family, Specie]`, i.e. coarse label first, fine label second."
    specie, family = o['Specie'], o['Family']
    return [family, specie]

In [None]:
def count_classes_per_coarse(labels_df):
    """Return `{coarse label: number of distinct fine labels}` for `labels_df`.

    `labels_df` needs a `coarse` and a `fine` column and may hold one row per
    sample: duplicates of the same (coarse, fine) pair are counted once, so the
    result is the number of classes per coarse group, not the number of rows.

    Keys come out in sorted order (pandas `groupby` default). Rows whose
    `coarse` value is missing are ignored, and missing `fine` values are not
    counted as a class.
    """
    # A single grouped pass replaces the per-family boolean filtering of the frame.
    fine_per_coarse = labels_df.groupby('coarse')['fine'].nunique()
    # Cast to plain ints so the values can be used directly as layer sizes.
    return {coarse: int(n_fine) for coarse, n_fine in fine_per_coarse.items()}

In [None]:
# Two-column view of `df` (one row per row of `df`): `fine` = Specie, `coarse` = Family.
labels_df = pd.DataFrame.from_dict({'fine': df['Specie'].values, 'coarse': df['Family'].values})
labels_df

In [None]:
# One entry per family; the model below uses the number of keys as the number of coarse classes.
classes_per_coarse = count_classes_per_coarse(labels_df)
classes_per_coarse

In [None]:
class Resnet50CustomModel(Module):
    """HD-CNN style network with one coarse branch and one fine branch per coarse class.

    Layout (all slices refer to `create_body(model, cut=-2)`):
      * `shared_layers`    - children `[:6]`, run once per batch;
      * `coarse_component` - children `[6:]`, followed by `coarse_head`, predicts the coarse class;
      * `fine_components`  - one independent copy of children `[6:]` per coarse class,
                             each followed by its own entry in `fine_heads`.

    Parameters
      model              - architecture passed to `create_body`.
      classes_per_coarse - mapping with one entry per coarse class; only its length is used.
      n_fine_classes     - width of every fine head (default 144, the value that used to be
                           hard-coded). Every head spans the full fine vocabulary because the
                           targets are indices into that global vocabulary.

    `forward` returns `{'fine_label': (batch, n_fine_classes), 'coarse1_label': (batch, n_coarse)}`.
    """
    def __init__(self, model, classes_per_coarse, n_fine_classes=144):
        n_coarse_classes = len(classes_per_coarse)
        # A single body supplies both the shared layers and the coarse branch: the two slices
        # are disjoint, so building a second body only to discard half of it is unnecessary.
        backbone = create_body(model, cut=-2)
        self.shared_layers = backbone[:6]
        self.coarse_component = backbone[6:]
        self.coarse_head = create_head(2048*2, n_coarse_classes)
        # Each fine branch needs its own weights, hence a fresh body per coarse class.
        self.fine_components = nn.ModuleList([create_body(model, cut=-2)[6:] for _ in range(n_coarse_classes)])
        self.fine_heads = nn.ModuleList([create_head(2048*2, n_fine_classes) for _ in range(n_coarse_classes)])

    def _fine_forward(self, feats, coarse_idx):
        "Run every sample of `feats` through the fine branch selected by its own entry in `coarse_idx`."
        batch = feats.shape[0]
        fine_label = None
        for idx in coarse_idx.unique().tolist():
            rows = (coarse_idx == idx).nonzero(as_tuple=True)[0].tolist()
            picked = rows
            if self.training and len(rows) == 1 and batch > 1:
                # Batch-norm layers in training mode cannot compute statistics from a single
                # sample. Borrow the next sample of the batch as a companion; its output row
                # is dropped below, it only keeps the statistics well defined.
                picked = rows + [(rows[0] + 1) % batch]
            out = self.fine_heads[idx](self.fine_components[idx](feats[picked]))
            out = out[:len(rows)]
            if fine_label is None:
                # Allocate lazily so dtype/device follow the branch output (matters under fp16).
                fine_label = out.new_zeros((batch, out.shape[1]))
            fine_label[rows] = out
        return fine_label

    def forward(self, x):
        x = self.shared_layers(x)
        coarse1_label = self.coarse_head(self.coarse_component(x))
        # Route each sample by its own predicted coarse class. Using only the first sample's
        # prediction for the whole batch would make a prediction depend on its batch neighbours.
        coarse_idx = coarse1_label.argmax(dim=1)
        fine_label = self._fine_forward(x, coarse_idx)
        return {
            'fine_label': fine_label,
            'coarse1_label': coarse1_label
        }

In [None]:
# Rebinds `model` (previously the `resnet50` constructor) to the network that is actually trained.
model = Resnet50CustomModel(resnet50, classes_per_coarse)

In [None]:
class CustomCategorize(DisplayedTransform):
    "Reversible transform of a `[coarse, fine]` label pair to ids in `vocab_coarse1` / `vocab`"
    loss_func,order=CrossEntropyLossFlat(),1
    def __init__(self, vocab=None, vocab_coarse1=None, vocab_coarse2=None, sort=True, add_na=False, num_y=1):
        # `vocab_coarse2` and `num_y` are only stored by `store_attr`; nothing in this class reads them.
        store_attr()
        self.vocab = None if vocab is None else CategoryMap(vocab, sort=sort, add_na=add_na)
        self.vocab_coarse1 = None if vocab_coarse1 is None else CategoryMap(vocab_coarse1, sort=sort, add_na=add_na)

    def setups(self, dsets):
        "Build whichever vocab was not supplied from `dsets`, whose items are `[coarse, fine]` pairs."
        # Only touch `dsets` when it exists and a vocab is actually missing: iterating it
        # unconditionally fails for `dsets=None` even when both vocabs were passed in.
        if dsets is not None:
            if self.vocab is None:
                self.vocab = CategoryMap([d[1] for d in dsets], sort=self.sort, add_na=self.add_na)
            if self.vocab_coarse1 is None:
                self.vocab_coarse1 = CategoryMap([d[0] for d in dsets], sort=self.sort, add_na=self.add_na)
        self.c = len(self.vocab)

    def encodes(self, o): return {'fine_label': TensorCategory(self.vocab.o2i[o[1]]),
                                  'coarse1_label': TensorCategory(self.vocab_coarse1.o2i[o[0]])
                                 }
    def decodes(self, o):
        "Decode the fine label; accepts the dict produced by `encodes` as well as a bare fine id."
        fine_id = o['fine_label'] if isinstance(o, dict) else o
        return Category      (self.vocab    [fine_id])

In [None]:
def CustomCategoryBlock(vocab=None, sort=True, add_na=False, num_y=1, vocab_coarse1=None):
    "`TransformBlock` for hierarchical `[coarse, fine]` targets, encoded by `CustomCategorize`"
    # `vocab` is the fine vocabulary and `vocab_coarse1` the coarse one; either may be left as
    # None, in which case `CustomCategorize.setups` derives it from the data.
    # `num_y` is forwarded (it used to be dropped) so the block and the transform agree.
    return TransformBlock(type_tfms=CustomCategorize(vocab=vocab, vocab_coarse1=vocab_coarse1,
                                                     sort=sort, add_na=add_na, num_y=num_y))

In [None]:
def custom_splitter(model):
    """Split `model` into parameter groups for freezing and discriminative learning rates.

    Groups, from first to last: shared layers, coarse branch, fine branches, and
    finally BOTH heads together. The heads are the only newly initialised layers,
    so they must share the last group: `learn.freeze()` keeps only the last group
    trainable, and a coarse head placed in an earlier group would stay frozen at
    its random initialisation during the frozen phase.

    NOTE(review): this yields four groups instead of five, so optimizer state saved
    with the five-group layout presumably cannot be restored as-is (model weights
    are unaffected) — confirm when resuming old checkpoints.
    """
    return [params(model.shared_layers),
            params(model.coarse_component),
            params(model.fine_components),
            params(model.coarse_head) + params(model.fine_heads)]

In [None]:
# DataBlock: images as inputs, `[Family, Specie]` pairs (built by `custom_get_y`) as targets.
# - `ColSplitter()` is used with its defaults, so the train/valid split is read from a column of `df`
#   (presumably fastai's default `is_valid` column — confirm the CSV provides it).
# - `ColReader(5, pref=path)` takes the image location from column index 5 and prefixes it with `path`.
# - Items are randomly cropped/resized to 336 px; `aug_transforms()` adds the default batch augmentations.
fishes = DataBlock(blocks = (ImageBlock, CustomCategoryBlock),
                 splitter=ColSplitter(),
                 get_x = ColReader(5, pref=path),
                 get_y=custom_get_y,
                 item_tfms=RandomResizedCrop(336, min_scale=0.5),
                 batch_tfms=aug_transforms())
dls = fishes.dataloaders(df)

In [None]:
# Inspect the training and validation datasets produced by the split.
dls.train_ds, dls.valid_ds

In [None]:
# Per-class loss weights, ordered like the vocabularies so that entry i weights class id i.
# The weight tables are indexed by label once; the previous form filtered the whole frame for
# every class and called `float()` on the resulting one-row Series, which recent pandas deprecates.
# A label missing from a table now fails with a KeyError naming that label.
specie_weight = weights_df.set_index('Specie')['Weight']
family_weight = weights_family_df.set_index('Family')['Weight']
weights = tensor([float(specie_weight.loc[c]) for c in dls.vocab]).cuda()
weights_family = tensor([float(family_weight.loc[c]) for c in dls.vocab_coarse1]).cuda()

In [None]:
def loss_func(out, targ):
    "Sum of the class-weighted cross-entropies of the fine (`weights`) and coarse (`weights_family`) outputs."
    fine_criterion = nn.CrossEntropyLoss(weight=weights)
    coarse_criterion = nn.CrossEntropyLoss(weight=weights_family)
    fine_loss = fine_criterion(out['fine_label'], targ['fine_label'])
    coarse_loss = coarse_criterion(out['coarse1_label'], targ['coarse1_label'])
    return fine_loss + coarse_loss

In [None]:
def custom_accuracy(inp, targ, axis=-1):
    "Accuracy of the fine-label prediction only; the coarse output is not scored."
    fine_pred = inp['fine_label'].argmax(dim=axis)
    fine_pred, fine_targ = flatten_check(fine_pred, targ['fine_label'])
    hits = (fine_pred == fine_targ).float()
    return hits.mean()

In [None]:
# Mixed-precision Learner on the custom model. `freeze()` is applied right away, so the first
# `fit` below only trains what the splitter's parameter groups leave unfrozen.
learn = Learner(dls, model, loss_func=loss_func, metrics=custom_accuracy,
                   splitter=custom_splitter).to_fp16()
learn.freeze()

In [None]:
# Print the layer summary of the frozen model.
learn.summary()

In [None]:
# Frozen phase: a single epoch at lr 3e-3 (the Learner was frozen when it was created).
learn.fit(1, 3e-3)

In [None]:
# Checkpoint after the frozen epoch (the name presumably encodes 1 frozen / 0 unfrozen epochs).
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs0')

In [None]:
def continue_training(pth_filename):
    """Build a fresh fp16 Learner on the global `dls`/`model`, load checkpoint `pth_filename`
    into it, unfreeze every parameter group and return the Learner ready for further `fit` calls."""
    resumed = Learner(dls, model, splitter=custom_splitter,
                      loss_func=loss_func, metrics=custom_accuracy)
    resumed = resumed.to_fp16()
    resumed.load(pth_filename)
    resumed.unfreeze()
    return resumed

In [None]:
# Unfrozen phase, epochs 1-10: the lr slice spreads 1e-6 .. 1e-4 across the splitter's parameter groups.
learn.unfreeze()
learn.fit(10, slice(1e-6,1e-4))

In [None]:
# Checkpoint after 10 unfrozen epochs.
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs10')

In [None]:
# Resume from the 10-epoch checkpoint with a freshly built, unfrozen Learner.
learn = continue_training('species144-hdcnn-resnet50-fepochs1-uepochs10')

In [None]:
# Unfrozen epochs 11-20, then checkpoint.
learn.fit(10, slice(1e-6,1e-4))
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs20')

In [None]:
# Resume from the 20-epoch checkpoint.
learn = continue_training('species144-hdcnn-resnet50-fepochs1-uepochs20')

In [None]:
# Unfrozen epochs 21-30, then checkpoint.
learn.fit(10, slice(1e-6,1e-4))
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs30')

In [None]:
# Resume from the 30-epoch checkpoint.
learn = continue_training('species144-hdcnn-resnet50-fepochs1-uepochs30')

In [None]:
# Unfrozen epochs 31-40, then checkpoint.
learn.fit(10, slice(1e-6,1e-4))
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs40')

In [None]:
# Resume from the 40-epoch checkpoint.
learn = continue_training('species144-hdcnn-resnet50-fepochs1-uepochs40')

In [None]:
# Unfrozen epochs 41-50, then checkpoint.
learn.fit(10, slice(1e-6,1e-4))
learn.save('species144-hdcnn-resnet50-fepochs1-uepochs50')

In [None]:
# Export the whole Learner for inference.
# NOTE(review): the export presumably pickles references to the custom pieces defined in this notebook
# (`custom_get_y`, `CustomCategorize`, `loss_func`, ...); they must be defined again before loading — confirm.
learn.export('species144-hdcnn-resnet50-fepochs1-uepochs50')