In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import pickle

# Step to load the model from the pickle file
def load_pickle(file_path):
    """Deserialize and return the single object stored in the pickle file *file_path*.

    Security note: unpickling can execute arbitrary code, so only load files you
    trust. Any classes referenced by the pickle must be importable or defined in the
    loading module for this to succeed.
    """
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)

# Model Definition
class ResNet(nn.Module):
    """Small ResNet-style image classifier with 4 outputs.

    Layout: 7x7 stride-2 conv -> ReLU -> 3x3 stride-2 max-pool -> 5 x ResBlock(64)
    -> 7x7 avg-pool -> flatten -> Dense(4).

    Submodules are created inline (``@nn.compact``), so Flax derives parameter names
    from creation order (Conv_0, ResBlock_0..4, Dense_0). Trained ``params`` loaded
    from a pickle must match this exact layout, so do not reorder, add or remove
    layers without retraining.

    ``__call__`` arguments:
        x: batch of images, presumably NHWC (e.g. (N, 224, 224, 3) as built by
            ``Model.predict``) -- TODO confirm for other callers.
        training: when False, softmax is applied and the output is class
            probabilities; when True, raw logits are returned. BatchNorm is
            commented out, so this flag currently affects nothing else.
    """
    @nn.compact
    def __call__(self, x, training=True):
        # Stem: 7x7 conv with stride 2 (halves height and width), no bias.
        x = nn.Conv(64, (7, 7), strides=(2, 2), padding='SAME', use_bias=False)(x)
        # BatchNorm disabled; kept for reference.
        # x = nn.BatchNorm(use_running_average=False)(x)
        x = nn.relu(x)
        # 3x3 max-pool with stride 2 halves height and width again.
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        # ResBlock is defined below; the name is resolved at call time, so the forward
        # reference is fine. ResBlock(64) uses stride (1, 1) by default, so the
        # spatial size and the 64 channels are unchanged by these blocks.
        for _ in range(5):
            x = ResBlock(64)(x)
        # NOTE(review): no strides are passed, so (if Flax's default stride of 1
        # applies, as assumed) this is a sliding 7x7 average rather than a global
        # pool, and the flattened vector is large. Do not change it without
        # retraining: the Dense kernel shape in the pickled params depends on it.
        x = nn.avg_pool(x, (7, 7))
        # Flatten everything except the batch axis.
        x = x.reshape((x.shape[0], -1))
        # 4 output logits, one per class.
        x = nn.Dense(4)(x)
        if not training:
            # Inference: turn logits into probabilities.
            x = nn.softmax(x)
        return x

class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convs plus an identity (or 1x1 projection) shortcut.

    Attributes:
        features: number of output channels of every conv in the block.
        strides: stride of the second 3x3 conv and of the projection shortcut.
    """
    features: int
    strides: tuple = (1, 1)
    @nn.compact
    def __call__(self, x):
        # Shared settings for both 3x3 convs: no bias, SAME padding.
        conv_args = {'use_bias': False, 'kernel_size': (3, 3), 'padding': 'SAME'}
        # NOTE(review): if BatchNorm is re-enabled, this single `norm` instance would be
        # reused for all three normalizations below, so they would share one set of BN
        # parameters. Each use probably needs its own nn.BatchNorm. Also,
        # use_running_average=False is training-mode behavior.
        # norm = nn.BatchNorm(use_running_average=False)
        y = nn.Conv(self.features, **conv_args)(x)
        # y = norm(y)
        y = nn.relu(y)
        # Only the second conv applies self.strides (downsampling, if any).
        y = nn.Conv(self.features, strides=self.strides, **conv_args)(y)
        # y = norm(y)
        if x.shape != y.shape:
            # Channels or spatial size changed: project the shortcut with a 1x1 conv
            # so it can be added to y. In ResNet above, every block gets a 64-channel
            # input with stride 1, so shapes match and this branch is not taken.
            x = nn.Conv(self.features, kernel_size=(1, 1), strides=self.strides, use_bias=False)(x)
            # x = norm(x)
        # Residual addition followed by ReLU.
        return nn.relu(y + x)

class Model:
    """Bundle of a Flax module class, its trained parameters and class labels.

    Objects of this class are presumably what ``ModelObj.pkl`` contains. For
    unpickling to keep working, the class name and its attributes (``model``,
    ``params``, ``classes``) must stay stable.
    """

    def __init__(self, nn_class, params, classes):
        # Instantiate the module definition; the trained weights live in ``params``.
        self.model = nn_class()
        self.params = params
        self.classes = classes

    def apply_model(self, inputs):
        """Run a forward pass with ``training=False`` and return the module's output."""
        variables = {'params': self.params}
        return self.model.apply(variables, inputs, training=False)

    def predict(self, input_image):
        """Return the label in ``self.classes`` with the highest score for one image.

        ``input_image`` may be any array-like holding exactly 224*224*3 values. It is
        reshaped to a batch of one channel-last image. No normalization is applied
        here -- TODO confirm this matches the preprocessing used during training.
        """
        batch = np.array(input_image).reshape(1, 224, 224, 3)
        scores = self.apply_model(jnp.array(batch))
        best = int(np.argmax(scores, axis=1).item())
        return self.classes[best]

# Load the pickled Model object used by the GUI below. The relative path is
# resolved against the current working directory, so ModelObj.pkl must be there.
# Unpickling presumably also needs Model/ResNet/ResBlock to be defined above.
model = load_pickle(file_path="ModelObj.pkl")

import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk

class SidebarApp(tk.Tk):
    """Main window: a sidebar with "Home"/"Classify" buttons and a content frame.

    The Classify page lets the user pick an image file. It then shows a preview of
    the image the model saw and the label returned by the module-level
    ``model.predict``.
    """

    def __init__(self):
        super().__init__()

        self.title("Eye classification")
        self.geometry("500x400")

        self.sidebar = tk.Frame(self, width=300, bg="blue")
        self.sidebar.pack(side="left", fill="y")

        self.home_button = tk.Button(self.sidebar, text="Home", command=self.go_to_home)
        self.home_button.pack(fill="x", padx=10, pady=10)

        self.button1 = tk.Button(self.sidebar, text="Classify", command=self.go_to_about)
        self.button1.pack(fill="x", padx=10, pady=10)

        self.main_frame = tk.Frame(self, bg="darkgray")
        self.main_frame.pack(side="right", fill="both", expand=True)

        self.label = tk.Label(
            self.main_frame,
            text="Welcome to the Eye Classification App",
            font=("Arial", 12),
            padx=20,
            pady=20,
            justify="left",
        )
        self.label.pack()

        # Preview and result widgets are created once. They are packed only on the
        # Classify page, after an image has been read.
        self.image_label = tk.Label(self.main_frame)
        self.prediction_field = tk.Label(self.main_frame, font=("Arial", 12, "bold"), bg="darkgray")

        # The "Read Image File" button exists only while the Classify page is shown.
        self.read_file_button = None

    def _clear_classify_page(self):
        """Remove or hide every widget that belongs to the Classify page."""
        if self.read_file_button is not None:
            self.read_file_button.destroy()
            self.read_file_button = None
        self.image_label.pack_forget()
        self.prediction_field.pack_forget()

    def _show_prediction(self, text):
        """Display *text* in the result label, packed below any preview image."""
        self.prediction_field.config(text=text)
        self.prediction_field.pack(pady=10)

    def go_to_home(self):
        """Show the welcome text and hide all Classify-page widgets, including the file button."""
        self.label.config(text="Welcome to the Eye Classification App")
        self._clear_classify_page()

    def go_to_about(self):
        """Show the Classify page: a prompt plus a fresh "Read Image File" button."""
        self.label.config(text="Select an eye scan to classify")
        self._clear_classify_page()
        self.read_file_button = tk.Button(self.main_frame, text="Read Image File", command=self.read_file_on_about_page)
        self.read_file_button.pack()

    def read_file_on_about_page(self):
        """Ask for an image file, classify it, then show a preview and the prediction."""
        file_path = filedialog.askopenfilename(
            title="Select a file",
            # Each pattern must be its own item. A single "*.png;*.jpg" string is
            # only understood by the Windows dialog, not by Tk on Linux/macOS.
            filetypes=[("Image files", ("*.png", "*.jpg", "*.jpeg")), ("All files", "*.*")],
        )
        if not file_path:
            return  # dialog was cancelled

        try:
            # The context manager closes the file handle that Image.open keeps open.
            with Image.open(file_path) as image:
                # Force 3 channels: PNGs are often RGBA, grayscale or palette images,
                # which would break the 224x224x3 reshape below.
                rgb_image = image.convert("RGB")
        except OSError as err:  # also covers PIL.UnidentifiedImageError
            self.image_label.pack_forget()
            self._show_prediction(f"Could not open image: {err}")
            return

        model_input = np.asarray(rgb_image.resize((224, 224))).reshape(1, 224, 224, 3)
        pred = model.predict(model_input)  # label from model.classes (the predicted disease)

        # Preview the image the model saw. Image.fromarray cannot take the batched
        # (1, 224, 224, 3) array, so drop the batch axis first.
        preview = Image.fromarray(model_input[0])
        preview.thumbnail((300, 300))
        photo = ImageTk.PhotoImage(preview)
        self.image_label.config(image=photo)
        # Keep a Python reference; otherwise the PhotoImage is garbage-collected and
        # Tk shows a blank image.
        self.image_label.image = photo
        # Re-pack the result label after the image so it always appears below it.
        self.prediction_field.pack_forget()
        self.image_label.pack()

        self._show_prediction(f"Prediction: {pred}")

if __name__ == "__main__":
    try:
        app = SidebarApp()
    except tk.TclError as err:
        # Tk needs a GUI display. Creating the root window fails on headless machines
        # and remote notebook kernels ("no display name and no $DISPLAY ...").
        raise RuntimeError(
            "Tkinter could not open a window because no display is available. "
            "Run this on a machine with a desktop session (or set $DISPLAY)."
        ) from err
    app.configure(bg="lightgray")
    app.mainloop()

TclError: no display name and no $DISPLAY environment variable