# Setup Notebook

In [None]:
#@markdown import libraries
import os
import sys
import json
import logging
from typing import List
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import sqlalchemy as sqla
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.ext.automap import automap_base

# Setup Cloud Database

In [None]:
# Authentication helper provided by the google.colab package.
from google.colab import auth

# Prompt the notebook user to authenticate
# (presumably these credentials back the gcloud / Cloud client calls in later cells).
auth.authenticate_user()

# Setup Google Cloud Project

In [None]:
#@markdown Please enter your GCP Project ID
project_id = "" #@param {type:"string"}
# Stop the cell early when the form field was left empty.
assert project_id, "Please enter your Google Project ID to continue"
# Shell escape: set this project in the gcloud configuration.
!gcloud config set project {project_id}

In [None]:
# @markdown Set IAM policy binding
user_account = !gcloud auth list --filter=status:ACTIVE --format="value(account)"
print("Active User Account: ", user_account[0])

out = !gcloud projects add-iam-policy-binding {project_id} \
  --member=user:{user_account[0]} \
  --role="roles/cloudsql.client"

In [None]:
#@title Connect to Cloud SQL Instance

#@markdown Please fill in the both the Google Cloud region and name of your Cloud SQL instance. Once filled in, run the cell.

# Please fill in these values.
region = "us-central1" #@param {type:"string"}
instance_name = "" #@param {type:"string"}

assert region, "Please enter a Google Cloud region"
assert instance_name, "Please enter the name of your instance"

# check if Cloud SQL instance exists in the provided region
database_version = !gcloud sql instances describe {instance_name} --format="value(databaseVersion)"
if database_version[0].startswith("POSTGRES"):
  print("Existing Postgres Cloud SQL Instance found!")
else:
  print("No existing Cloud SQL instance found!")


In [None]:
# Build the "<project>:<region>:<instance>" identifier from the form values.
instance_connection_name = ":".join((project_id, region, instance_name))
print("Instance Connection Name: ", instance_connection_name)

# Connect to Database

In [None]:
#@title Install Cloud SQL connector
!{sys.executable} -m pip install -q cloud-sql-python-connector["pg8000"]

In [None]:
#@title Class Templates
from abc import ABC, abstractmethod

class DatabaseManager(ABC):
  """Abstract base for objects that own a database engine and a session.

  Both `_engine` and `_session` start out as None; subclasses are expected to
  populate them in `setup()`.
  """

  def __init__(self):
    self._engine = None
    self._session = None

  @abstractmethod
  def setup(self):
    """Create the engine and the session. Must be implemented by subclasses."""

  def get_session(self):
    """Return the stored session object (None until one has been set)."""
    return self._session

  def get_engine(self):
    """Return the stored engine object (None until one has been set)."""
    return self._engine

  def close_session(self):
    """Close the stored session; do nothing when no session has been set."""
    if self._session is not None:
      self._session.close()


class TaskManager(ABC):
  """Common base class for task managers; it defines no behaviour of its own."""

  def __init__(self):
    """Nothing to initialise at this level."""
    return None

  def list_tasks(self):
    """Placeholder hook; returns None."""
    return None

In [None]:
#@title Cloud SQL Database
#@markdown Run this cell
import reprlib
from datetime import datetime
from typing import List
import sqlalchemy as sqla
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base

# Declarative base class that the Cloud SQL ORM models in this cell derive from.
Base = declarative_base()

class ProjectConfigurations(Base):
  """ORM model for a row of the `project_configurations` table."""

  __tablename__ = 'project_configurations'

  project_id: Mapped[int] = mapped_column(
      sqla.Integer, primary_key=True, autoincrement=True)
  project_title: Mapped[str] = mapped_column(sqla.String(255))
  cloud_bucket_name: Mapped[str] = mapped_column(sqla.String(255))
  cloud_bucket_prefix: Mapped[str] = mapped_column(sqla.String(255))
  comma_separated_labels: Mapped[str] = mapped_column(sqla.String(255))
  max_annotation_per_example: Mapped[int] = mapped_column(sqla.Integer)
  # Both timestamps default to the current UTC time when no value is supplied.
  completion_deadline = mapped_column(sqla.TIMESTAMP, default=datetime.utcnow)
  created_at = mapped_column(sqla.TIMESTAMP, default=datetime.utcnow)

  def __repr__(self)-> str:
    """Return a multi-line, human-readable summary (project_id is not included)."""
    banner = f'Project Configurations\n{"*" * 26} \n'
    shown_fields = (
        ('project title', self.project_title),
        ('cloud_bucket_name', self.cloud_bucket_name),
        ('cloud_bucket_prefix', self.cloud_bucket_prefix),
        ('comma_separated_labels', self.comma_separated_labels),
        ('max_annotation_per_example', self.max_annotation_per_example),
        ('completion_deadline', self.completion_deadline),
        ('project creation date', self.created_at),
    )
    return banner + '\n'.join(f'{title}={value!r}' for title, value in shown_fields)


class Annotator(Base):
  """ORM model for the `annotators` table."""

  __tablename__ = 'annotators'

  annotator_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True,
                                            nullable=False, autoincrement=True)
  username: Mapped[str] = mapped_column(sqla.String(255))
  email: Mapped[str] = mapped_column(sqla.String(255), nullable=True)

  # Reverse sides of Annotation.annotator and AssignedAnnotator.annotator.
  annotations = relationship("Annotation", back_populates="annotator")
  assigned_annotators = relationship("AssignedAnnotator", back_populates="annotator")

  def __repr__(self)-> str:
    """Return a one-line description with balanced parentheses."""
    return (f'Annotator(annotator_id={self.annotator_id!r},' +
            f'username={self.username!r}, ' +
            f'email={self.email!r})')


class Annotation(Base):
  """ORM model for the `annotations` table: one label given by one annotator to one example."""

  __tablename__ = 'annotations'

  annotation_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True, autoincrement=True)
  label: Mapped[str] = mapped_column(sqla.String(60), nullable=False)
  example_id: Mapped[str] = mapped_column(sqla.String(255), sqla.ForeignKey('examples.example_id'), nullable=False)
  example: Mapped['Example'] = relationship("Example", back_populates='annotations')

  annotator_id: Mapped[int] = mapped_column(sqla.Integer, sqla.ForeignKey('annotators.annotator_id'), nullable=False)
  annotator: Mapped['Annotator'] = relationship("Annotator", back_populates='annotations')

  def __repr__(self)-> str:
    """Return a one-line description; the annotator repr is abbreviated by reprlib."""
    return (f'Annotation(annotation_id={self.annotation_id!r},' +
            f'label={self.label!r}, ' +
            f'annotator={reprlib.repr(self.annotator)})')


class Example(Base):
  """ORM model for the `examples` table: an image to be labeled."""

  __tablename__ = 'examples'

  example_id: Mapped[str] = mapped_column(sqla.String(255), nullable=False, primary_key=True)
  image_path: Mapped[str] = mapped_column(sqla.String(255), nullable=False)

  # Reverse sides of Annotation.example and AssignedAnnotator.example.
  annotations: Mapped[List['Annotation']] = relationship("Annotation", back_populates="example")
  assigned_annotators: Mapped[List['AssignedAnnotator']] = relationship("AssignedAnnotator", back_populates="example")

  def __repr__(self)-> str:
    """Return a one-line description; the annotation list is abbreviated by reprlib."""
    return (f'Example(example_id={self.example_id!r},' +
            f'image_path={self.image_path!r}, ' +
            f'annotations={reprlib.repr(self.annotations)})')


class AssignedAnnotator(Base):
  """ORM model for the `assigned_annotators` table: links an annotator to an example."""

  __tablename__ = 'assigned_annotators'

  assignment_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True, autoincrement=True)
  example_id: Mapped[str] = mapped_column(sqla.String(255), sqla.ForeignKey('examples.example_id'), nullable=False)
  annotator_id: Mapped[int] = mapped_column(sqla.Integer, sqla.ForeignKey('annotators.annotator_id'), nullable=False)

  example: Mapped['Example'] = relationship("Example", back_populates="assigned_annotators")
  annotator: Mapped['Annotator'] = relationship("Annotator", back_populates="assigned_annotators")

  def __repr__(self)-> str:
    """Return a one-line description with balanced parentheses."""
    return (f'AssignedAnnotator(assignment_id={self.assignment_id!r}, ' +
            f'example_id={self.example_id!r}, ' +
            f'annotator_id={self.annotator_id})')

In [None]:
#@markdown Database Manager
import logging
from sqlalchemy import create_engine, exc, inspect
from sqlalchemy.orm import sessionmaker, scoped_session
from google.cloud.sql.connector import Connector

class CloudDatabaseManager(DatabaseManager):
  """Database manager that reaches a Cloud SQL instance through the Cloud SQL
  Python Connector using the pg8000 driver."""

  def __init__(self, username: str,
               password: str,
               db_name: str,
               instance_connection_name: str):
    """Store the connection settings; nothing is opened until `setup()`.

    Args:
      username: Database user name.
      password: Database password.
      db_name: Name of the database to connect to.
      instance_connection_name: Instance identifier passed to the connector.
    """
    super().__init__()
    self.username = username
    self.password = password
    self.db_name = db_name
    self.instance_connection_name = instance_connection_name
    # Connector instance created by setup(); kept so the pool can open
    # additional connections later on.
    self._connector = None

  def _get_connection(self):
    """Open a new pg8000 connection; used as the engine pool's `creator`."""
    return self._connector.connect(
        self.instance_connection_name,
        "pg8000",
        user=self.username,
        password=self.password,
        db=self.db_name
    )

  def setup(self):
    """Create the engine and a scoped session.

    The pool gets a connection factory rather than one pre-opened connection,
    so every pooled connection is a distinct one.

    Raises:
      Exception: whatever the connector or SQLAlchemy raised; the error is
        logged and re-raised so a failed setup is not mistaken for success.
    """
    connector = None
    try:
      connector = Connector()
      self._connector = connector
      engine = create_engine("postgresql+pg8000://", creator=self._get_connection)
      # Open one connection right away so bad credentials or an unreachable
      # instance are reported here rather than on the first query.
      with engine.connect():
        pass
      session_factory = sessionmaker(bind=engine)
      self._engine = engine
      self._session = scoped_session(session_factory)
      logging.info("Database Manager Initialized")
    except Exception as e:
      logging.error(f"Error setting up database: {e}")
      if connector is not None:
        connector.close()
      self._connector = None
      raise



class ExternalTaskManager(TaskManager):
  """Task manager that reads assignments from and writes annotations to the
  Cloud SQL database on behalf of one annotator (identified by username)."""

  def __init__(self,
               username: str,
               password: str,
               db_name: str,
               instance_connection_name: str
               ):
    """Create the underlying CloudDatabaseManager; no connection is opened yet.

    Args:
      username: Database user name; also used to look up the annotator row.
      password: Database password.
      db_name: Name of the database to connect to.
      instance_connection_name: Instance identifier passed to the connector.
    """
    super().__init__()
    self.cloud_database_mgr = CloudDatabaseManager(
        username=username,
        password=password,
        db_name=db_name,
        instance_connection_name=instance_connection_name
      )
    self.username = username
    self.annotator_id = None
    # Filled in by setup(); defined here so the attributes always exist.
    self._engine = None
    self._session = None

  def setup(self):
    """Connect to the database and resolve this user's annotator_id.

    Raises:
      ValueError: if no annotator row matches `self.username`.
    """
    self.cloud_database_mgr.setup()
    self._engine = self.cloud_database_mgr.get_engine()
    self._session = self.cloud_database_mgr.get_session()
    annotator = (
        self._session.query(Annotator)
        .filter_by(username=self.username)
        .first()
    )
    if annotator is None:
      raise ValueError(f"No annotator with username {self.username!r} exists in the database")
    self.annotator_id = annotator.annotator_id

  def list_tables(self):
    """Return the table names reported by the engine's inspector."""
    return inspect(self._engine).get_table_names()

  def load_table_by_tablename(self, tablename: str):
    """Reflect the database and return the automapped class for `tablename`
    (None when no such class was mapped)."""
    # Local name chosen so the module-level declarative `Base` is not shadowed.
    reflected_base = automap_base()
    reflected_base.prepare(autoload_with=self._engine)
    return reflected_base.classes.get(tablename)

  def retrieve_assigned_tasks(self):
    """Return all Example rows assigned to this annotator's username."""
    assigned_examples = (
      self._session.query(Example)
        .join(AssignedAnnotator, Example.example_id == AssignedAnnotator.example_id)
        .join(Annotator, Annotator.annotator_id == AssignedAnnotator.annotator_id)
        .filter(Annotator.username == self.username)
        .all()
      )
    return assigned_examples

  def get_project_configs(self):
    """Return the first ProjectConfigurations row, or None if the table is empty."""
    return self._session.query(ProjectConfigurations).first()

  def export_annotation(self, example_id, label):
    """Insert or update this annotator's label for `example_id`.

    Nothing is written when the annotator is not assigned to the example.
    Database errors are printed and the transaction is rolled back.
    """
    session = self._session

    # Check if annotator is assigned to this task
    assignment = (
        session.query(AssignedAnnotator)
        .join(Annotator,
              Annotator.annotator_id == AssignedAnnotator.annotator_id)
        .filter(AssignedAnnotator.example_id == example_id)
        .filter(Annotator.username == self.username)
        .first()
    )
    if assignment is None:
      print(f"No assignment found for example_id {example_id} and annotator {self.username}")
      return

    existing_annotation = (
        session.query(Annotation)
        .filter_by(example_id=example_id, annotator_id=self.annotator_id)
        .first()
    )
    if existing_annotation is not None:
      # Update existing annotation
      try:
        existing_annotation.label = label
        session.commit()
        print(f"Update- example_id: {example_id} label: {label}")
      except Exception as e:
        print("Error updating annotation:", e)
        session.rollback()
      return

    # Write new annotation
    try:
      new_annotation = Annotation(
          example_id=example_id,
          label=label,
          annotator_id=self.annotator_id
      )
      session.add(new_annotation)
      session.commit()
      print(f"New- example_id: {example_id} label: {label}")
    except Exception as e:
      print("Error adding new annotation:", e)
      session.rollback()

  def retrieve_annotation(self, example_id):
    """Return this annotator's stored label for `example_id`, or None.

    Uses `.first()` like `export_annotation` does, so both methods pick the
    same row and unexpected duplicate rows do not raise here.
    """
    annotation = (
        self._session.query(Annotation)
        .filter(Annotation.example_id == example_id,
                Annotation.annotator_id == self.annotator_id)
        .first()
    )
    return annotation.label if annotation is not None else None


In [None]:
#@markdown Enter name of Cloud SQL database to connect to
database_name = "" #@param {type: "string"}
assert database_name, "Please enter database name to connect"

#@markdown Enter username to access Cloud SQL Database
username = "" #@param {type: "string"}
assert username, "Please enter a valid username"

#@markdown Enter password to access Cloud SQL Database
password = "" #@param {type:"string"}
# NOTE(review): unlike the two fields above, the password is not checked for emptiness.

# Connection settings keyed by the parameter names that CloudDatabaseManager
# and ExternalTaskManager accept.
auth_configs = {'username': username,
                'password': password,
                'db_name': database_name,
                'instance_connection_name': instance_connection_name
                }

# NOTE(review): the disabled lines below refer to AnnotatorDatabaseManager,
# which is not defined in the visible part of this notebook.
# extern_database_manager = AnnotatorDatabaseManager(**auth_configs)
# extern_database_manager.setup()
# extern_session = extern_database_manager.get_session()

## Init Cloud Task Manager

In [None]:
# Build the Cloud SQL task manager from the form values, then connect it.
external_task_mgr = ExternalTaskManager(
    username=username,
    password=password,
    db_name=database_name,
    instance_connection_name=instance_connection_name,
)
external_task_mgr.setup()

In [None]:
# Bare expression: the notebook shows the returned list of table names as the cell output.
external_task_mgr.list_tables()

# Get Project Configs

In [None]:
# Fetch the first row of the project_configurations table.
project_configs = external_task_mgr.get_project_configs()
# Bare expression: the notebook shows the object's repr as the cell output.
project_configs

In [None]:
# @title preprocess labels
# This should be done in the Admin Notebook
def preprocess_labels(labels_string: str):
  """Split a comma-separated string into a list of non-empty labels.

  Only space characters are trimmed from each piece; pieces that are empty
  after trimming are dropped. The original order is kept.
  """
  trimmed_pieces = (piece.strip(" ") for piece in labels_string.split(","))
  return [piece for piece in trimmed_pieces if piece]

# Turn the project's comma-separated label string into a list of labels.
config_project_labels = preprocess_labels(project_configs.comma_separated_labels)
# Bare expression: the notebook shows the list as the cell output.
config_project_labels

In [None]:
#@title Set Local Project Configs
#@markdown Enter local storage path
# Directory for the images to be labeled; passed as the download target further down.
local_images_path = "images" #@param {type: "string"}
# Stop the cell early when the form field was left empty.
assert local_images_path, "Please enter a pathname where images to be labeled are stored"

## Retrieve Assigned Tasks

In [None]:
# Example rows assigned to the current username, loaded from the Cloud SQL database.
assigned_examples = external_task_mgr.retrieve_assigned_tasks()

# Internal Database

In [None]:
#@markdown
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import ForeignKey
from sqlalchemy.orm import declarative_base

# Define Base class
# Separate declarative base for the local SQLite models, so their tables live
# in their own metadata rather than in the Cloud SQL `Base`.
SQLiteBase = declarative_base()

class LocalExample(SQLiteBase):
  """Row of the local SQLite `examples_table`: one assigned example and its label."""

  __tablename__ = 'examples_table'

  id = Column(Integer, primary_key=True, autoincrement=True)
  example_id = Column(String(255), nullable=False, unique=True)
  image_path = Column(String(255), nullable=False)
  image_filename = Column(String(255), nullable=False)
  # Label chosen for this example; NULL until one is set. InternalTaskManager
  # reads, writes and filters on this attribute, so it must be a mapped column
  # for the value to be persisted. The length matches Annotation.label.
  # NOTE: metadata.create_all() skips tables that already exist, so a SQLite
  # file created before this column was added has to be recreated.
  label = Column(String(60), nullable=True)

  def __repr__(self) -> str:
    """Return a one-line description with balanced parentheses."""
    return (f'Example(example_id={self.example_id!r},' +
            f'image_path={self.image_path!r},' +
            f'image_name={self.image_filename!r})'
            )

# Internal Database Manager


In [None]:
#@markdown
from typing import List
import sqlalchemy as sqla
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import declarative_base

class InternalDatabaseManager(DatabaseManager):
  """Database manager for the notebook-local SQLite database."""

  def __init__(self, db_path: str, **kwargs) -> None:
    """Remember the database location; nothing is opened until `setup()`.

    Args:
      db_path: Path of the SQLite file, inserted into a `sqlite:///` URL.
      **kwargs: Extra keyword arguments forwarded to `sqlalchemy.create_engine`.
    """
    # The base initialiser sets _engine and _session to None.
    super().__init__()
    self.db_path = db_path
    self.kwargs = kwargs

  def setup(self) -> None:
    """Initialize the database: create the engine, the tables and the session."""
    self._engine = sqla.create_engine(f'sqlite:///{self.db_path}', **self.kwargs)
    # NOTE(review): MetaData.bind is a legacy 1.x feature; under SQLAlchemy 2.x
    # this assignment presumably has no effect - confirm before removing it.
    SQLiteBase.metadata.bind = self._engine
    SQLiteBase.metadata.create_all(self._engine)

    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
    self._session = scoped_session(session_factory)

    # Gives the SQLiteBase models a `.query` attribute tied to this session.
    SQLiteBase.query = self._session.query_property()
    print("Database Manager Initialized")

  def status(self) -> None:
    """Check Database Status.

    Raises:
      ValueError: if `setup()` has not created an engine yet.
    """
    if self._engine is None:
        raise ValueError("Database Manager not initialized.")

  def drop_db(self) -> None:
    """Drop all tables of the local database.

    Raises:
      ValueError: if `setup()` has not created an engine yet.
    """
    # Report a missing engine clearly rather than passing None to drop_all().
    self.status()
    SQLiteBase.metadata.drop_all(self._engine)

  # get_session() is inherited from DatabaseManager.

# Internal Task Manager

In [None]:
#@markdown
from google.cloud import storage
class InternalTaskManager:
  """Keeps a local SQLite record of the assigned examples and their labels,
  and downloads the corresponding images from Cloud Storage."""

  def __init__(self, assigned_examples, db_configs):
    """Store the inputs; the database is only opened by `init()`.

    Args:
      assigned_examples: Iterable of objects exposing `example_id` and `image_path`.
      db_configs: Keyword arguments for InternalDatabaseManager.
    """
    self.assigned_examples = assigned_examples
    self.db_configs = db_configs
    self._database_manager = None
    self.session = None

  def init(self):
    """Create and set up the local database manager."""
    self._database_manager = InternalDatabaseManager(**self.db_configs)
    self._database_manager.setup()

  def get_session(self):
    """Return the local database session.

    Raises:
      RuntimeError: if `init()` has not been called yet.
    """
    if self._database_manager is None:
      raise RuntimeError("InternalTaskManager.init() must be called before using the database")
    return self._database_manager.get_session()

  def get_remote_image_paths(self):
    """Return the `image_path` of every assigned example."""
    return [example.image_path for example in self.assigned_examples]

  def download_assigned_images(
    self,
    bucket_name: str,
    bucket_prefix: str,
    output_path: str
  ) -> None:
    """Download every assigned image into `output_path` and record it locally.

    A failed download is reported and does not stop the loop; the example is
    still recorded in the local database.
    """
    print("Downloading new examples")
    os.makedirs(output_path, exist_ok=True)

    client = storage.Client()
    bucket = client.bucket(bucket_name)

    for example in self.assigned_examples:
      image_path = example.image_path
      image_filename = image_path.split('/')[-1]
      blob = bucket.blob(os.path.join(bucket_prefix, image_filename))
      image_filepath = os.path.join(output_path, image_filename)

      try:
        blob.download_to_filename(image_filepath)
        example.image_filename = image_filename
      except Exception as e:
        print(f"Error downloading image: {e}")

      self.add_example_to_database(example_id=example.example_id,
                    image_path=example.image_path,
                    image_filename=image_filename
                    )
    print("Done!")

  def get_all_examples(self) -> List["LocalExample"]:
    """Get all examples from the database."""
    return self.get_session().query(LocalExample).all()

  def get_labeled_examples(self):
    """Return the local examples whose label is not NULL."""
    session = self.get_session()
    return session.query(LocalExample).filter(LocalExample.label.is_not(None)).all()

  def get_example_by_id(self, id: str):
    """Return the local example whose `example_id` equals `id`, or None."""
    session = self.get_session()
    return session.query(LocalExample).filter(LocalExample.example_id == id).first()

  def add_example_to_database(self, example_id, image_path, image_filename, label=None) -> None:
    """Add an example to the local database unless it is already present."""
    session = self.get_session()
    # Use the same string form for the lookup and for the insert.
    example_id = str(example_id)
    already_stored = session.query(LocalExample).filter_by(example_id=example_id).first()
    if already_stored is not None:
      return

    example = LocalExample(example_id=example_id, image_path=image_path, image_filename=image_filename)
    if label:
      example.label = label
    session.add(example)
    session.commit()

  def update_label(self, image_filename, new_label) -> None:
    """Update the label of the example stored under `image_filename`."""
    session = self.get_session()
    example = session.query(LocalExample).filter(LocalExample.image_filename == image_filename).first()
    if example is None:
      logging.error(f"No example found with name {image_filename}")
      return
    example.label = new_label
    session.commit()
    logging.info(f"Label updated for Image {image_filename}")

  def get_label(self, image_filename):
    """Return the stored label for `image_filename`, or None when the example
    is unknown or has no label."""
    session = self.get_session()
    example = session.query(LocalExample).filter(LocalExample.image_filename == image_filename).first()
    if example is None:
      return None
    return getattr(example, "label", None)

  def update_example(self, example: "LocalExample") -> None:
    """Update an example."""
    session = self.get_session()
    session.add(example)
    session.commit()

  def get_example_id_from_filename(self, filename: str):
    """Return the `example_id` stored for `filename`, or None if there is none."""
    session = self.get_session()
    match = session.query(LocalExample).filter(LocalExample.image_filename == filename).first()
    return match.example_id if match is not None else None

  def delete_example(self, example: "LocalExample") -> None:
    """Delete an example."""
    session = self.get_session()
    session.delete(example)
    session.commit()

In [None]:
# The local database is the SQLite file `sqlite.db`.
db_configs = dict(db_path='sqlite.db')
internal_task_mgr = InternalTaskManager(
    assigned_examples=assigned_examples,
    db_configs=db_configs,
)
internal_task_mgr.init()

In [None]:
# Fetch the assigned images from the project's bucket into the local image directory.
internal_task_mgr.download_assigned_images(
    bucket_name=project_configs.cloud_bucket_name,
    bucket_prefix=project_configs.cloud_bucket_prefix,
    output_path=local_images_path,
)

# Annotations Display UI

In [None]:
#@markdown
from ipywidgets import IntSlider
from ipywidgets import FloatSlider
from ipywidgets import Button
from ipywidgets import RadioButtons
from ipywidgets import VBox
from ipywidgets import HBox
from ipywidgets import interact
from PIL import Image

def update_label(selected_label):
  """Change handler for the label radio buttons.

  Stores the newly selected label for the image currently shown, first in the
  local database and then in the Cloud SQL database. A new value of None
  (no selection) is ignored.

  Args:
    selected_label: Change notification; its 'new' entry is the selected label.
  """
  new_label = selected_label['new']
  if new_label is None:
    return
  image_filename = images[image_index.value].split('/')[-1]
  internal_task_mgr.update_label(image_filename=image_filename, new_label=new_label)
  example_id = internal_task_mgr.get_example_id_from_filename(image_filename)
  external_task_mgr.export_annotation(example_id, new_label)

def load_images(local_images_path):
    """Return the paths of the files directly inside `local_images_path`.

    Entries are sorted by name so the navigation order is the same on every
    run, and sub-directories are skipped because they cannot be opened as
    images. Paths are built with '/' because the callers split on it.
    """
    return [
        f'{local_images_path}/{name}'
        for name in sorted(os.listdir(local_images_path))
        if os.path.isfile(os.path.join(local_images_path, name))
    ]

def display_image(image_index, size):
    """Render the image at position `image_index` in a square figure.

    Args:
        image_index: Index into the module-level `images` list.
        size: Width and height passed to `plt.figure(figsize=...)`.
    """
    image_path = images[image_index]
    plt.figure(figsize=(size, size))
    # The context manager closes the image file once it has been drawn.
    with Image.open(image_path) as image:
        plt.imshow(image)
    plt.axis('off')
    plt.show()

def get_current_label(index):
  """Show the stored annotation for image `index` in the label radio buttons.

  The value is set programmatically, so the `update_label` change handler is
  detached while it is assigned; otherwise simply navigating between images
  would write the just-loaded label back to the databases.
  """
  image_filename = images[index].split('/')[-1]
  example_id = internal_task_mgr.get_example_id_from_filename(image_filename)
  stored_label = external_task_mgr.retrieve_annotation(example_id)

  try:
    label_radio_buttons.unobserve(update_label, names='value')
    handler_was_attached = True
  except ValueError:
    # The handler was not registered, so there is nothing to restore later.
    handler_was_attached = False

  try:
    label_radio_buttons.value = stored_label
  finally:
    if handler_was_attached:
      label_radio_buttons.observe(update_label, names='value')

def on_next_button_click(b):
    """Advance to the next image unless the last one is already shown."""
    last_position = len(images) - 1
    if image_index.value >= last_position:
        return
    image_index.value += 1

def on_previous_button_click(b):
    """Step back to the previous image unless the first one is already shown."""
    if image_index.value <= 0:
        return
    image_index.value -= 1

def on_image_change(a):
  """Change handler for the image slider: refresh the radio buttons for the
  image now selected. The notification argument `a` itself is not used."""
  get_current_label(image_index.value)

# Define default values and widgets
# Use the directory configured in the "Set Local Project Configs" cell, i.e.
# the same one the images were downloaded to, not a hard-coded path.
images = load_images(local_images_path)
if not images:
  raise FileNotFoundError(f"No images found in {local_images_path!r}; download the assigned images first")

# Define default values and widgets
image_index = IntSlider(min=0, max=len(images)-1, step=1, value=0)
size_slider = FloatSlider(value=10.0, min=1.0, max=20.0, step=0.1, description='Image Size:', continuous_update=False)
next_button = Button(description="Next")
previous_button = Button(description="Previous")
label_radio_buttons = RadioButtons(value=None, options=config_project_labels, layout={'width': 'max-content'}, description='String Label: \n', disabled=False)

# Preselect the stored annotation of the first image. This happens before the
# change handlers are registered, so it does not trigger an export.
image_filename = images[0].split('/')[-1]
example_id = internal_task_mgr.get_example_id_from_filename(image_filename)
label_radio_buttons.value = external_task_mgr.retrieve_annotation(example_id)

# Define widget interactions
next_button.on_click(on_next_button_click)
previous_button.on_click(on_previous_button_click)
label_radio_buttons.observe(update_label, names='value')
image_index.observe(on_image_change, names='value')

# If label_radio_button changes after clicking next or previous, do not make update
# Display UI
interact(display_image, image_index=image_index, size=size_slider)
navigation_buttons = HBox([previous_button, next_button])
ui_elements = HBox([label_radio_buttons, VBox([navigation_buttons])])
display(ui_elements)