In [None]:
from research_lib.utils.data_access_utils import S3AccessUtils
from keras.models import load_model
import torch
from torch import nn

In [None]:
# Shared S3 helper used by every download/upload cell below.
# '/root/data' is passed as its root directory (presumably the local cache
# that download_from_url writes into — confirm against S3AccessUtils).
s3 = S3AccessUtils(
    '/root/data',
)

In [None]:
class Network(nn.Module):
    """Fully connected regressor shared by the weight and k-factor models.

    Layer sizes are 24 -> 256 -> 128 -> 64 -> 1, with a ReLU after every
    hidden layer and no activation on the final output. Both estimators
    currently use this identical architecture.

    The attribute names ``fc1``, ``fc2``, ``fc3`` and ``output`` are part of
    the contract: ``load_state_dict`` matches checkpoint keys against them,
    and ``convert_to_tf`` reads them by name.
    """

    def __init__(self):
        super().__init__()
        # Modules are registered in the same order as before, so parameter
        # initialisation and state_dict ordering are unchanged.
        self.fc1 = nn.Linear(24, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten all non-batch dimensions of ``x`` and run the MLP on it.

        The first dimension is treated as the batch. Everything after it is
        flattened, and must total 24 features to match ``fc1``.
        """
        features = x.view(x.shape[0], -1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            features = self.relu(hidden_layer(features))
        return self.output(features)

# Download the trained PyTorch weight-estimation checkpoint and load it into Network.
pytorch_weight_estimation_model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb'
# download_from_url returns a 3-tuple; the first element is used as the local
# file path (presumably the cached download — confirm against S3AccessUtils).
pytorch_weight_estimation_model_f, _, _ = s3.download_from_url(pytorch_weight_estimation_model_url)
pytorch_weight_estimation_model = Network()
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine. The weights must end up as host arrays for Keras anyway.
# NOTE(review): torch.load unpickles the file; only use it on trusted sources.
pytorch_weight_estimation_model.load_state_dict(
    torch.load(pytorch_weight_estimation_model_f, map_location='cpu')
)



In [None]:
def convert_to_tf(pytorch_model,
                  template_path='/root/data/alok/biomass_estimation/playground/model_keras_replicate_original_prod_v2.h5'):
    """Copy the weights of a PyTorch ``Network`` into a Keras model of the same shape.

    The Keras model stored at ``template_path`` serves as a placeholder
    architecture. Its layers 1-4 are overwritten with the weights of ``fc1``,
    ``fc2``, ``fc3`` and ``output``. Layer 0 is left untouched (presumably the
    Keras input layer; confirm against the template file).

    Args:
        pytorch_model: module exposing ``nn.Linear`` attributes ``fc1``,
            ``fc2``, ``fc3`` and ``output``, e.g. ``Network``. The module may
            live on CPU or GPU.
        template_path: path of the Keras ``.h5`` model that receives the
            weights. Defaults to the original hard-coded placeholder.

    Returns:
        The Keras model loaded from ``template_path`` with the PyTorch weights
        applied.
    """
    keras_model = load_model(template_path)
    torch_layers = (pytorch_model.fc1, pytorch_model.fc2,
                    pytorch_model.fc3, pytorch_model.output)
    for keras_index, linear in enumerate(torch_layers, start=1):
        # nn.Linear stores its weight as (out_features, in_features), while a
        # Keras Dense kernel is (in_features, out_features), hence the transpose.
        # detach().cpu().numpy() gives Keras plain host arrays even when the
        # module is on a GPU or its parameters require grad.
        kernel = linear.weight.detach().cpu().numpy().T
        bias = linear.bias.detach().cpu().numpy()
        keras_model.layers[keras_index].set_weights([kernel, bias])
    return keras_model
    

In [None]:
# Transfer the PyTorch weight-model parameters into the Keras placeholder architecture.
tf_model = convert_to_tf(pytorch_model=pytorch_weight_estimation_model)

In [None]:
# Persist the converted weight model locally before uploading it.
weight_model_h5_path = '/root/data/alok/biomass_estimation/playground/weight_model_synthetic_data.h5'
tf_model.save(weight_model_h5_path)

In [None]:
# Upload the Keras weight model under the same S3 prefix as its PyTorch source checkpoint.
bucket = 'aquabyte-models'
key = 'biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.h5'
weight_model_h5_path = '/root/data/alok/biomass_estimation/playground/weight_model_synthetic_data.h5'
s3.s3_client.upload_file(weight_model_h5_path, bucket, key)

<h1>Convert the K-factor model from PyTorch to Keras</h1>

In [None]:
# Download the trained PyTorch k-factor checkpoint and load it into Network.
pytorch_kf_estimation_model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb'
pytorch_kf_estimation_model_f, _, _ = s3.download_from_url(pytorch_kf_estimation_model_url)
pytorch_kf_estimation_model = Network()
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only use it on trusted sources.
pytorch_kf_estimation_model.load_state_dict(
    torch.load(pytorch_kf_estimation_model_f, map_location='cpu')
)


In [None]:
# The k-factor network shares the weight network's architecture, so the same converter applies.
tf_kf_model = convert_to_tf(pytorch_model=pytorch_kf_estimation_model)

In [None]:
# Persist the converted k-factor model locally before uploading it.
kf_model_h5_path = '/root/data/alok/biomass_estimation/playground/kf_predictor_v2.h5'
tf_kf_model.save(kf_model_h5_path)

In [None]:
# Upload the Keras k-factor model under the same S3 prefix as its PyTorch source checkpoint.
bucket = 'aquabyte-models'
key = 'k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.h5'
kf_model_h5_path = '/root/data/alok/biomass_estimation/playground/kf_predictor_v2.h5'
s3.s3_client.upload_file(kf_model_h5_path, bucket, key)

In [None]:
# Load the playground copy of the k-factor checkpoint (k-factor/playground/...).
pytorch_kf_estimation_model_url_1 = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/playground/kf_predictor_v2.pb'
pytorch_kf_estimation_model_f_1, _, _ = s3.download_from_url(pytorch_kf_estimation_model_url_1)
pytorch_kf_estimation_model_1 = Network()
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only use it on trusted sources.
pytorch_kf_estimation_model_1.load_state_dict(
    torch.load(pytorch_kf_estimation_model_f_1, map_location='cpu')
)


In [None]:
# Load the trained_models copy of the k-factor checkpoint.
# NOTE(review): this is the exact URL already downloaded in the earlier k-factor
# cell. It looks like it was reloaded to compare against the playground copy,
# but no comparison follows. TODO: add the comparison or drop these cells.
pytorch_kf_estimation_model_url_2 = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb'
pytorch_kf_estimation_model_f_2, _, _ = s3.download_from_url(pytorch_kf_estimation_model_url_2)
pytorch_kf_estimation_model_2 = Network()
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only use it on trusted sources.
pytorch_kf_estimation_model_2.load_state_dict(
    torch.load(pytorch_kf_estimation_model_f_2, map_location='cpu')
)
