## Building the Model and Loading the Weights

In [123]:
from zoobot.tensorflow.estimators import define_model, preprocess
from zoobot.tensorflow.data_utils import image_datasets
from zoobot.tensorflow.predictions import predict_on_dataset

import glob
from PIL import Image
import numpy as np
import cv2 as cv
import pandas as pd

In [7]:
model = define_model.get_model(
    34,
    300,
    224,
    224,
    which_maxvit = 'MaxViTTiniest',
    use_effnet = False
)



In [8]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 perma_random_rotation_1 (Pe  (None, 300, 300, 1)      0         
 rmaRandomRotation)                                              
                                                                 
 perma_random_flip_1 (PermaR  (None, 300, 300, 1)      0         
 andomFlip)                                                      
                                                                 
 perma_random_crop_1 (PermaR  (None, 224, 224, 1)      0         
 andomCrop)                                                      
                                                                 
 maxvit (MaxViT)             (None, 1280)              8006460   
                                                                 
 top_dropout (PermaDropout)  (None, 1280)              0         
                                                      

In [9]:
checkpoint_path = r'C:\Users\oryan\Documents\AstroHack\data\vit_2xgpu\vit_2xgpu\checkpoint'

In [10]:
model.load_weights(checkpoint_path).expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2141650d888>

In [113]:
files = glob.glob('C:/Users/oryan/Documents/AstroHack/transformer/data/*.jpeg')[:50]

In [114]:
right_shape = []
for i in files:
    with Image.open(i) as im:
        im_arr = np.asarray(im)
    shap = im_arr.shape
    if shap[0] < 301:
        # print('Wrong shape!')
        right_shape.append(i)

In [115]:
file_format = 'jpeg'
initial_size = 300
batch_size = 8

In [116]:
raw_image_ds = image_datasets.get_image_dataset(
        [x for x in right_shape], file_format, initial_size, batch_size
    )

In [117]:
def process_images(ds, initial_size):
    preprocessing_config = preprocess.PreprocessingConfig(
        label_cols = [],
        input_size = initial_size,
        make_greyscale = True,
    )

    processed_ds = preprocess.preprocess_dataset(ds, preprocessing_config)

    return processed_ds

In [118]:
processed_images_ds = process_images(raw_image_ds, initial_size)



In [119]:
for images, labels in processed_images_ds.take(1):
    # images x(images.numpy().shape)
    pass

In [120]:
# model(images)

In [135]:
def make_predictions(images, model, save_folder):
    # cut = 0.65

    n_samples = 1
    label_cols = list(np.linspace(0,34,34).astype(str))
    save_loc = f'{save_folder}/predictions.csv'

    predict_on_dataset.predict(images, model, n_samples, label_cols, save_loc)

    predictions = pd.read_csv(f'{save_folder}/predictions.csv')

    # predictions_export = (
    #     predictions
    #     .assign(matchid = predictions.id_str.apply(lambda x: os.path.splitext(os.path.basename(x))[0]))
    #     .assign(base_binary_prediction = predictions.interacting_pred.apply(lambda x: 1 if ast.literal_eval(x)[0] > cut else 0))
    #     .assign(certain_binary_prediction = predictions.interacting_pred.apply(lambda x: 1 if ast.literal_eval(x)[0] > 0.95 else 0))
    #     .rename(columns={'id_str' : 'file_path'})
    # )

    return predictions

In [136]:
predictions = make_predictions(
    processed_images_ds,
    model,
    'C:/Users/oryan/Documents/AstroHack/transformer/results'
)

In [137]:
predictions = pd.read_csv(f'C:/Users/oryan/Documents/AstroHack/transformer/results/predictions.csv')

In [138]:
predictions

Unnamed: 0,id_str,0.0_pred,1.0303030303030303_pred,2.0606060606060606_pred,3.090909090909091_pred,4.121212121212121_pred,5.151515151515151_pred,6.181818181818182_pred,7.212121212121212_pred,8.242424242424242_pred,...,24.727272727272727_pred,25.757575757575758_pred,26.78787878787879_pred,27.818181818181817_pred,28.848484848484848_pred,29.87878787878788_pred,30.909090909090907_pred,31.939393939393938_pred,32.96969696969697_pred,34.0_pred
0,C:/Users/oryan/Documents/AstroHack/transformer...,[11.166573524475098],[3.148249626159668],[1.0847786664962769],[1.0687271356582642],[3.173170566558838],[1.0422924757003784],[1.7416069507598877],[1.018814206123352],[2.860600709915161],...,[3.322587490081787],[32.74802017211914],[1.1647144556045532],[1.0106602907180786],[1.0107676982879639],[9.32662296295166],[15.36890697479248],[3.8387486934661865],[1.0134378671646118],[1.0095880031585693]
1,C:/Users/oryan/Documents/AstroHack/transformer...,[6.788135528564453],[1.9600162506103516],[1.0637024641036987],[1.152871012687683],[19.53359603881836],[1.042995810508728],[1.3942131996154785],[1.008370280265808],[2.3976991176605225],...,[1.823885202407837],[10.90322208404541],[1.263911485671997],[1.0105137825012207],[1.0068806409835815],[6.6767168045043945],[19.883630752563477],[4.5675883293151855],[1.021529197692871],[1.005927324295044]
2,C:/Users/oryan/Documents/AstroHack/transformer...,[10.795567512512207],[3.0345747470855713],[1.1008212566375732],[1.0224202871322632],[1.6800919771194458],[1.0270754098892212],[3.187511444091797],[1.0083261728286743],[2.145723342895508],...,[5.158139228820801],[38.76292037963867],[1.1249372959136963],[1.0107742547988892],[1.0139683485031128],[11.054275512695312],[15.087864875793457],[3.0635783672332764],[1.0041499137878418],[1.0059078931808472]
3,C:/Users/oryan/Documents/AstroHack/transformer...,[3.5851845741271973],[1.6770044565200806],[1.0905841588974],[1.006962537765503],[1.3856629133224487],[1.0067074298858643],[4.139183044433594],[1.001949667930603],[1.3097683191299438],...,[2.656308650970459],[24.18233299255371],[1.0429153442382812],[1.0030498504638672],[1.0024315118789673],[6.51828145980835],[11.838818550109863],[1.8339256048202515],[1.0009729862213135],[1.0005024671554565]
4,C:/Users/oryan/Documents/AstroHack/transformer...,[4.012651443481445],[1.9774327278137207],[1.1133501529693604],[1.0053588151931763],[1.4117399454116821],[1.0078954696655273],[3.8324337005615234],[1.0010483264923096],[1.2193838357925415],...,[1.7188588380813599],[28.015888214111328],[1.0361979007720947],[1.0016998052597046],[1.0013679265975952],[4.7896037101745605],[8.754656791687012],[1.469595193862915],[1.0003888607025146],[1.0004844665527344]
5,C:/Users/oryan/Documents/AstroHack/transformer...,[3.2422103881835938],[1.6102979183197021],[1.0377357006072998],[1.002174973487854],[1.1437783241271973],[1.0045591592788696],[2.9534082412719727],[1.0009682178497314],[1.2169281244277954],...,[2.309535026550293],[29.77753448486328],[1.0172048807144165],[1.0017058849334717],[1.0017478466033936],[5.975432395935059],[8.881021499633789],[1.5374878644943237],[1.00049889087677],[1.0004048347473145]
6,C:/Users/oryan/Documents/AstroHack/transformer...,[4.36931037902832],[2.125666618347168],[1.0759303569793701],[1.0166505575180054],[2.192948341369629],[1.0153779983520508],[1.9383022785186768],[1.0015541315078735],[1.2722382545471191],...,[2.543548583984375],[28.068878173828125],[1.0584274530410767],[1.0031875371932983],[1.0031347274780273],[5.193803787231445],[6.937614917755127],[1.6314903497695923],[1.00187087059021],[1.0010030269622803]
7,C:/Users/oryan/Documents/AstroHack/transformer...,[2.1745123863220215],[1.2669990062713623],[1.0071488618850708],[1.0001693964004517],[1.0221967697143555],[1.0027180910110474],[2.0361075401306152],[1.0002059936523438],[1.1393556594848633],...,[2.7806291580200195],[25.329784393310547],[1.0071109533309937],[1.0003937482833862],[1.0004584789276123],[6.723080158233643],[7.127288818359375],[1.3583818674087524],[1.0001769065856934],[1.000079870223999]
8,C:/Users/oryan/Documents/AstroHack/transformer...,[3.1627755165100098],[1.4825949668884277],[1.0174760818481445],[1.0007286071777344],[1.08544921875],[1.0032256841659546],[2.2557148933410645],[1.0005828142166138],[1.235787034034729],...,[3.1471352577209473],[26.8687801361084],[1.0157177448272705],[1.0008854866027832],[1.0007878541946411],[6.705142021179199],[8.6661376953125],[1.5008786916732788],[1.0004699230194092],[1.000212550163269]
9,C:/Users/oryan/Documents/AstroHack/transformer...,[2.3629913330078125],[1.5766063928604126],[1.0323116779327393],[1.0029423236846924],[1.4476666450500488],[1.005659818649292],[1.904199481010437],[1.0005571842193604],[1.192040205001831],...,[1.9859174489974976],[20.636730194091797],[1.0243037939071655],[1.0015208721160889],[1.0014452934265137],[4.312623500823975],[7.4971208572387695],[1.5322635173797607],[1.000594973564148],[1.0001938343048096]
