<a target="_blank" href="https://colab.research.google.com/github/instadeepai/SKAInnotate/blob/main/Admin_Notebook.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# SKAInnotate - Admin Notebook

SKAInnotate is a data annotation tool that enables teams to annotate data from a Colab notebook.
This admin notebook provides the following key features:
* Setting up a new annotation project
* Setting project configs
* Adding and assigning annotators to tasks
* Exporting annotations

In [None]:
#@title Install Cloud SQL connector
import sys
!{sys.executable} -m pip install -q cloud-sql-python-connector["pg8000"]

In [None]:
#@markdown import libraries
import json
import logging
import os
import shlex
from abc import ABC, abstractmethod
from functools import partial

import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import reprlib
import pandas as pd
from typing import List
from datetime import datetime

import sqlalchemy as sqla
from sqlalchemy import create_engine
from sqlalchemy import exc
from sqlalchemy.sql import text

from sqlalchemy.orm import Session
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session

import google.cloud.storage as storage
from google.cloud.sql.connector import Connector

## Authenticate to Google Colab

Authenticating in Google Colab grants access to Google services that require authentication, such as Google Cloud Storage and Google Cloud SQL.

In [None]:
from google.colab import auth
# Start Colab's sign-in flow; the Google services used later in this notebook
# (Cloud Storage, Cloud SQL) need an authenticated user.
auth.authenticate_user()

## Setup Google Cloud Project

In [None]:
#@markdown Please enter your GCP Project ID
project_id = "" #@param {type:"string"}
assert project_id, "Please enter your Google Project ID to continue"
# Make this project the default for the `gcloud` commands run later in the notebook.
!gcloud config set project {project_id}

In [None]:
#@markdown Set IAM policy binding and enable Cloud SQL Admin API
user_account = !gcloud auth list --filter=status:ACTIVE --format="value(account)"
print("Active User Account: ", user_account[0])

out=!gcloud projects add-iam-policy-binding {project_id} \
  --member=user:{user_account[0]} \
  --role="roles/cloudsql.client"

!gcloud services enable sqladmin.googleapis.com

In [None]:
#@title Create Cloud SQL Instance

#@markdown Enter the region of your Google Cloud SQL instance.

# Please fill in these values.
region = "" #@param {type:"string"}

#@markdown Enter the name of your Google Cloud SQL instance.
instance_name = "" #@param {type:"string"}

assert region, "Please enter a Google Cloud region"
assert instance_name, "Please enter the name of your instance"

#@markdown Enter a password to be used for 'postgres' database user
root_password = "" #@param {type:"string"}
assert root_password, "Please enter a password for 'postgres' database user"
# check if Cloud SQL instance exists in the provided region
database_version = !gcloud sql instances describe {instance_name} --format="value(databaseVersion)"
if database_version[0].startswith("POSTGRES"):
  print("Found existing Postgres Cloud SQL Instance!")
else:
  print("Creating new Cloud SQL instance...")
  !gcloud sql instances create {instance_name} --database-version=POSTGRES_15 \
    --region={region} --cpu=1 --memory=4GB --root-password={root_password} \
    --database-flags=cloudsql.iam_authentication=On


In [None]:
# The Cloud SQL connector identifies the instance as PROJECT:REGION:INSTANCE.
instance_connection_name = ":".join((project_id, region, instance_name))
print("Instance Connection Name: ", instance_connection_name)

## Create Database

A new database is created if it does not already exist. It stores the schemas and tables for the annotation task.

In [None]:
#@markdown Please Enter database name to create
database_name = "" #@param {type: "string"}
assert database_name, "Please enter a name for the database to be created"

!gcloud sql databases create {database_name} --instance={instance_name}

In [None]:
##@title Task Manager

#@markdown Setup_lib.py

class DatabaseManagerBase(ABC):
  """Abstract holder for a database engine and session.

  Subclasses implement `setup()`, which is expected to fill in `_engine`
  and `_session`.
  """

  def __init__(self):
    # Both handles stay unset until a subclass's setup() runs.
    self._engine, self._session = None, None

  @abstractmethod
  def setup(self):
    """Create the engine and session; subclasses must provide this."""

  def get_session(self):
    """Return the stored session object (None before setup())."""
    return self._session

  def get_engine(self):
    """Return the stored engine object (None before setup())."""
    return self._engine

  def close_session(self):
    """Close the stored session."""
    current = self._session
    current.close()


class TaskManagerBase(ABC):
  """Minimal base interface for task managers."""

  def __init__(self):
    """Keeps no state at this level."""

  def list_tasks(self):
    """Listing hook; the base implementation returns None."""
    return None

class DatabaseManager(DatabaseManagerBase):
  """Connects to a Cloud SQL Postgres database and prepares the ORM schema.

  Args:
    username: Database user to connect as.
    password: Password for `username`.
    db_name: Name of the database on the instance.
    instance_connection_name: Cloud SQL connection name
      ("PROJECT:REGION:INSTANCE").
    **kwargs: Stored on `self.kwargs`; not otherwise used by this class.
  """

  def __init__(
      self,
      username: str,
      password: str,
      db_name: str,
      instance_connection_name: str,
      **kwargs):
    super().__init__()
    self.username = username
    self.password = password
    self.db_name = db_name
    self.instance_connection_name = instance_connection_name

    self.kwargs = kwargs
    # Created in setup(); kept so close() can release its resources.
    self._connector = None

  def setup(self):
    """Create the engine, create any missing tables and build a scoped session.

    Errors from connecting or from creating tables propagate to the caller.
    """
    self._connector = Connector()
    connector = self._connector

    def getconn():
      # Called by the SQLAlchemy pool whenever it needs a new DB-API connection.
      return connector.connect(
          self.instance_connection_name,
          "pg8000",
          user=self.username,
          password=self.password,
          db=self.db_name
      )

    self._engine = sqla.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    Base.metadata.bind = self._engine
    # Creates the tables of all models declared on `Base` that do not exist yet.
    Base.metadata.create_all(self._engine)

    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
    self._session = scoped_session(session_factory)

    # Enables the `Model.query` shortcut (e.g. `Annotator.query`).
    Base.query = self._session.query_property()
    logging.info("Database Manager Initialized")

  def get_session(self):
    """Return the `scoped_session` registry created by setup()."""
    return self._session

  def close(self):
    """Release the session registry, the engine's pool and the Cloud SQL connector."""
    if self._session is not None:
      self._session.remove()
    if self._engine is not None:
      self._engine.dispose()
    if self._connector is not None:
      self._connector.close()
      self._connector = None


class TaskManager(TaskManagerBase):
  """Admin-side operations on the annotation project database.

  Wraps a `DatabaseManager` and provides helpers to register examples and
  annotators, store the project configuration, assign examples to
  annotators, grant database privileges and export annotations.

  Call `setup()` before any other method. The methods below open the session
  through `self._session()`, which setup() provides.

  Args:
    username: Database user the admin connects as.
    password: Password for `username`.
    db_name: Name of the project database.
    instance_connection_name: Cloud SQL connection name
      ("PROJECT:REGION:INSTANCE").
  """

  def __init__(self,
               username: str,
               password: str,
               db_name: str,
               instance_connection_name: str
               )-> None:
    self._database_manager = DatabaseManager(
        username=username,
        password=password,
        db_name=db_name,
        instance_connection_name=instance_connection_name
      )
    self.username = username

  def setup(self):
    """Connect to the database and cache its engine and session registry."""
    self._database_manager.setup()
    self._engine = self._database_manager.get_engine()
    self._session = self._database_manager.get_session()

  @staticmethod
  def download_csv_from_bucket(bucket_name, bucket_prefix, csv_filename):
    """Download `<bucket_prefix>/<csv_filename>` from a GCS bucket.

    Args:
      bucket_name: Name of the GCS bucket.
      bucket_prefix: Folder inside the bucket; may be empty or None.
      csv_filename: Name of the CSV file inside `bucket_prefix`.

    Returns:
      The local path the file was written to ('metadata.csv').
    """
    temp_file = 'metadata.csv'
    client = storage.Client()
    bucket = client.bucket(bucket_name=bucket_name)
    # `or ''` lets an unset prefix address a file at the bucket root.
    blob = bucket.blob(os.path.join(bucket_prefix or '', csv_filename))
    blob.download_to_filename(temp_file)

    return temp_file

  def add_tasks(self, bucket_name, bucket_prefix, csv_filename=None):
    """Register the examples listed in a CSV metadata file in the bucket.

    Args:
      bucket_name: GCS bucket holding the data and the CSV file.
      bucket_prefix: Folder inside the bucket that contains the CSV file.
      csv_filename: Name of the CSV file. No other task source is
        implemented, so nothing is added when it is omitted.
    """
    if not csv_filename:
      logging.warning("No csv_filename given; no tasks were added.")
      return
    session = self._session()
    metadata_file = self.download_csv_from_bucket(bucket_name, bucket_prefix, csv_filename)
    self._add_tasks_from_csv(session, metadata_file)

  @staticmethod
  def _add_tasks_from_csv(session, csv_path: str):
    '''
    Adding Tasks from CSV metadata
    Args:
      session: ORM session used to look up and insert `Example` rows.
      csv_path: path to CSV file with metadata; it must have `example_id`
        and `image` columns.
    Return:
      None
    '''
    tasks_df = pd.read_csv(csv_path)

    for _, row in tasks_df.iterrows():
      # `examples.example_id` is a String column, while pandas may parse
      # numeric ids as numbers, so normalise once and use it for both the
      # lookup and the insert.
      example_id = str(row['example_id'])
      existing_task = session.query(Example).filter_by(example_id=example_id).first()

      if existing_task is None:
        task = Example(example_id=example_id, image=row['image'])
        session.add(task)
        session.commit()
      else:
        logging.info(f"Task with example_id '{example_id}' already exists. Skipping.")
    logging.info("\n")

  def set_project_configs(
      self,
      project_title: str = None,
      cloud_bucket_name: str = None,
      cloud_bucket_prefix: str = None,
      comma_separated_labels: str = None,
      max_annotators_per_example: int = None,
      completion_deadline: datetime = None
      )-> None:
    """Create the project configuration row, or update the existing one.

    Title, bucket name and labels are required. When a row already exists,
    only the arguments that are truthy overwrite the stored values.

    Raises:
      AssertionError: if a required argument is missing.
    """
    assert project_title, 'Please Enter project title'
    assert cloud_bucket_name, 'Please Enter source bucket name'
    assert comma_separated_labels, 'Please Enter task labels'

    session = self._session()
    # Use the first row regardless of its id, the same one get_project_configs()
    # returns, instead of assuming the autoincrement id is 1.
    existing_config = (
      session.query(ProjectConfigurations)
      .order_by(ProjectConfigurations.project_id)
      .first()
    )

    if existing_config:
      if project_title: existing_config.project_title = project_title
      if cloud_bucket_name: existing_config.cloud_bucket_name = cloud_bucket_name
      if cloud_bucket_prefix: existing_config.cloud_bucket_prefix = cloud_bucket_prefix
      if comma_separated_labels: existing_config.comma_separated_labels = comma_separated_labels
      if max_annotators_per_example: existing_config.max_annotation_per_example = max_annotators_per_example
      if completion_deadline: existing_config.completion_deadline = completion_deadline

    else:
      config_values = dict(
        project_title=project_title,
        cloud_bucket_name=cloud_bucket_name,
        cloud_bucket_prefix=cloud_bucket_prefix,
        comma_separated_labels=comma_separated_labels,
        max_annotation_per_example=max_annotators_per_example
      )
      # Pass the deadline only when one was given, so the column default applies otherwise.
      if completion_deadline:
        config_values['completion_deadline'] = completion_deadline
      session.add(ProjectConfigurations(**config_values))
    session.commit()

  def assign_tasks(self, max_annotators_per_example):
    """Assign every example to up to `max_annotators_per_example` annotators.

    Annotators are chosen round-robin so the load is spread evenly. Each
    example gets distinct annotators, because the per-example count is capped
    at the number of annotators. (annotator, example) pairs that already exist
    are not inserted again, so re-running does not create duplicates.

    Returns:
      The newly created `AssignedAnnotator` rows. The list is empty when
      there are no annotators, no examples, or a non-positive maximum.
    """
    session = self._session()

    # A stable ordering keeps the round-robin deterministic across runs.
    annotators = session.query(Annotator).order_by(Annotator.annotator_id).all()
    examples = session.query(Example).order_by(Example.example_id).all()
    num_annotators = len(annotators)

    if not num_annotators or not examples or not max_annotators_per_example:
      logging.warning("Nothing to assign: need annotators, examples and a positive "
                      "max_annotators_per_example.")
      return []

    per_example = min(max_annotators_per_example, num_annotators)
    existing_pairs = {
      (assigned.annotator_id, assigned.example_id)
      for assigned in session.query(AssignedAnnotator).all()
    }

    assignments = []
    for i, example in enumerate(examples):
      for j in range(per_example):
        annotator = annotators[(i * per_example + j) % num_annotators]
        pair = (annotator.annotator_id, example.example_id)
        if pair in existing_pairs:
          continue
        existing_pairs.add(pair)
        assignments.append(
          AssignedAnnotator(annotator_id=annotator.annotator_id,
                            example_id=example.example_id))
    session.add_all(assignments)
    session.commit()
    return assignments

  def show_assignments(self, limit=None):
    """Print up to `limit` assignments (all of them when limit is falsy)."""
    session = self._session()
    assigned_annotators = session.query(AssignedAnnotator).all()

    limit = limit or len(assigned_annotators)
    for assigned_annotator in assigned_annotators[:limit]:
      print(f'Example ID: {assigned_annotator.example_id}, assigned to Annotator: {assigned_annotator.annotator.username}')

  def get_task_assignments(self):
    """Return (example_id, image, username) rows, one per assignment."""
    session = self._session()
    # Example and Annotator are linked only through assigned_annotators.
    stmt = (
      sqla.select(Example.example_id, Example.image, Annotator.username)
      .select_from(Example)
      .join(AssignedAnnotator, Example.example_id == AssignedAnnotator.example_id)
      .join(Annotator, AssignedAnnotator.annotator_id == Annotator.annotator_id)
    )
    return session.execute(stmt).all()

  def get_partially_assigned_tasks(self, max_annotators_per_example=None):
    '''Get tasks that have partially been assigned.

    Args:
      max_annotators_per_example: When given, only examples with fewer than
        this many assignments are returned. When None (default), every
        example is returned.
    '''
    session = self._session()
    query = (
      session.query(Example)
      .outerjoin(Example.assigned_annotators)
      .group_by(Example.example_id)
    )
    if max_annotators_per_example is not None:
      query = query.having(
        sqla.func.count(AssignedAnnotator.assignment_id) < max_annotators_per_example)
    return query.all()

  def add_annotators(self, annotators):
    """Insert annotators given as dicts with 'username' and 'email' keys.

    Existing usernames are skipped with a warning. A failed insert is rolled
    back and logged, and the remaining annotators are still processed.
    """
    session = self._session()

    for annotator_data in annotators:
      username = annotator_data.get('username')
      email = annotator_data.get('email')

      existing_annotator = session.query(Annotator).filter_by(username=username).first()

      if existing_annotator is None:
        annotator = Annotator(username=username, email=email)
        try:
          session.add(annotator)
          session.commit()
          logging.info(f"Added new annotator '{username}'")
        except Exception as e:
          session.rollback()
          logging.error(f"Error adding annotator '{username}': {e}")
      else:
        logging.warning(f"Annotator with username '{username}' already exists")

  def remove_annotator(self, username, email=None):
    """Delete the annotator(s) with the given username and their assignments.

    Args:
      username: Username to remove. A list of usernames or of
        ``{'username': ...}`` dicts is also accepted.
      email: Unused; kept for interface compatibility.

    Returns:
      Number of annotator rows deleted.

    Raises:
      sqlalchemy.exc.SQLAlchemyError: after rolling back, if the delete fails.
        NOTE(review): this presumably happens when the annotator already has
        rows in `annotations`, which reference annotators via a non-null FK.
    """
    session = self._session()
    if isinstance(username, (list, tuple)):
      usernames = [entry.get('username') if isinstance(entry, dict) else entry
                   for entry in username]
    else:
      usernames = [username]

    annotator_ids = [
      annotator_id for (annotator_id,) in
      session.query(Annotator.annotator_id).filter(Annotator.username.in_(usernames))
    ]
    if not annotator_ids:
      return 0

    try:
      # Assignments reference annotators through a foreign key; remove them first.
      session.query(AssignedAnnotator).filter(
        AssignedAnnotator.annotator_id.in_(annotator_ids)).delete(synchronize_session=False)
      deleted = session.query(Annotator).filter(
        Annotator.annotator_id.in_(annotator_ids)).delete(synchronize_session=False)
      session.commit()
    except exc.SQLAlchemyError:
      session.rollback()
      raise
    return deleted

  def get_assigned_tasks(self, username):
    """Return the examples assigned to the annotator with `username`."""
    session = self._session()

    results = session.query(Example).\
      join(AssignedAnnotator, Example.example_id == AssignedAnnotator.example_id).\
      join(Annotator, AssignedAnnotator.annotator_id == Annotator.annotator_id).\
      filter(Annotator.username == username).all()

    return results

  def list_tasks(self, limit=None):
    """Return up to `limit` examples (all of them when limit is falsy)."""
    session = self._session()
    results = session.query(Example).all()
    limit = limit or len(results)
    return results[:limit]

  def list_annotators(self):
    """Return all annotator rows."""
    session = self._session()
    return session.query(Annotator).all()

  def get_completed_annotations(self):
    """Return a Result of (EXAMPLE ID, LABEL, ANNOTATOR USERNAME) rows.

    Only annotated examples appear (inner joins). Rows are ordered by
    annotator id.
    """
    session = self._session()
    stmt = (
        sqla.select(
            Example.example_id.label("EXAMPLE ID"),
            Annotation.label.label("LABEL"),
            Annotator.username.label("ANNOTATOR USERNAME")
        )
        .select_from(
            sqla.join(Example, Annotation, Example.example_id == Annotation.example_id)
            .join(Annotator, Annotation.annotator_id == Annotator.annotator_id)
        )
        .order_by(Annotator.annotator_id)
    )

    # The caller decides how to consume the Result (e.g. `.all()`).
    return session.execute(stmt)

  def get_project_configs(self):
    """Return the project configuration row with the lowest id, or None."""
    session = self._session()
    return (
      session.query(ProjectConfigurations)
      .order_by(ProjectConfigurations.project_id)
      .first()
    )

  def _quote_identifier(self, name):
    """Quote `name` as a Postgres identifier for safe use inside `text()` SQL."""
    quoted = self._engine.dialect.identifier_preparer.quote(name)
    # `text()` treats ":word" as a bind parameter; backslash-escape colons.
    return quoted.replace(':', '\\:')

  def grant_annotator_access(self, username):
    """Let `username` read all public tables and write to `annotations`.

    Identifiers cannot be bound parameters, so the role name is quoted
    rather than interpolated raw (it comes from user input).
    """
    session = self._session()
    role = self._quote_identifier(username)
    try:
      session.execute(text(f'GRANT SELECT ON ALL TABLES IN SCHEMA public TO {role};'))
      session.execute(text(f'GRANT INSERT, UPDATE on annotations TO {role};'))
      session.execute(text(f'GRANT USAGE, SELECT ON SEQUENCE annotations_annotation_id_seq TO {role};'))
      session.commit()
    except exc.SQLAlchemyError:
      # Keep the session usable for later calls after a failed GRANT.
      session.rollback()
      raise

  def create_roles(self, role_name, password='skai_1234'):
    """Create a login role and return all rows of `pg_roles`.

    Args:
      role_name: Name of the role to create; quoted as an identifier.
      password: Login password for the role. NOTE(review): the default is a
        well-known legacy value kept only for backward compatibility; always
        pass an explicit password.
    """
    session = self._session()
    role = self._quote_identifier(role_name)
    # CREATE ROLE cannot take bind parameters. Escape the literal by doubling
    # single quotes, and escape colons so `text()` does not see bind params.
    escaped_password = password.replace("'", "''").replace(':', '\\:')
    try:
      session.execute(text(f"CREATE ROLE {role} WITH LOGIN PASSWORD '{escaped_password}';"))
      session.commit()
    except exc.SQLAlchemyError:
      session.rollback()
      raise
    return session.execute(text('SELECT * FROM pg_roles;')).all()

  def rollback_session(self):
    """Roll back the current transaction, e.g. after a failed statement."""
    session = self._session()
    session.rollback()


## Database

In [None]:
#@markdown Run this cell

# Declarative base shared by every ORM model below. `Base.metadata` collects
# their tables, which DatabaseManager.setup() creates with create_all().
Base = declarative_base()

class ProjectConfigurations(Base):
  """Settings of the annotation project: title, data source, labels, limits.

  The notebook reads and updates one row of this table.
  """
  __tablename__ = 'project_configurations'

  project_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True, autoincrement=True)
  project_title: Mapped[str] = mapped_column(sqla.String(255))
  # Bucket holding the data to annotate, and the folder inside it.
  cloud_bucket_name: Mapped[str] = mapped_column(sqla.String(255))
  cloud_bucket_prefix: Mapped[str] = mapped_column(sqla.String(255))
  # Labels joined with ',' (e.g. "cat,dogs,skip").
  comma_separated_labels: Mapped[str] = mapped_column(sqla.String(255))
  max_annotation_per_example: Mapped[int] = mapped_column(sqla.Integer)
  # Defaults are UTC datetime objects taken at insert time, which the
  # TIMESTAMP columns store directly. (A "%x"-formatted string would be
  # locale-dependent and keep only a two-digit year.)
  completion_deadline = mapped_column(sqla.TIMESTAMP, default=datetime.utcnow)
  created_at = mapped_column(sqla.TIMESTAMP, default=datetime.utcnow)

  def __repr__(self)-> str:
    return (f'Project Configurations\n{"*" * 26} \n' +
            f'project title={self.project_title!r}\n' +
            f'cloud_bucket_name={self.cloud_bucket_name!r}\n' +
            f'cloud_bucket_prefix={self.cloud_bucket_prefix!r}\n' +
            f'comma_separated_labels={self.comma_separated_labels!r}\n' +
            f'max_annotation_per_example={self.max_annotation_per_example!r}\n' +
            f'completion_deadline={self.completion_deadline!r}\n' +
            f'project creation date={self.created_at!r}')


class Annotator(Base):
  """A person who labels examples."""
  __tablename__ = 'annotators'

  annotator_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True,
                                            nullable=False, autoincrement=True)
  username: Mapped[str] = mapped_column(sqla.String(255))
  email: Mapped[str] = mapped_column(sqla.String(255), nullable=True)

  # Labels produced by this annotator, and the examples assigned to them.
  annotations = relationship("Annotation", back_populates="annotator")
  assigned_annotators = relationship("AssignedAnnotator", back_populates="annotator")

  def __repr__(self)-> str:
    return (f'Annotator(annotator_id={self.annotator_id!r}, ' +
            f'username={self.username!r}, ' +
            f'email={self.email!r})')


class Annotation(Base):
  """A label given to one example by one annotator."""
  __tablename__ = 'annotations'

  annotation_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True, autoincrement=True)
  label: Mapped[str] = mapped_column(sqla.String(60), nullable=False)
  example_id: Mapped[str] = mapped_column(sqla.String(255), sqla.ForeignKey('examples.example_id'), nullable=False)
  example: Mapped['Example'] = relationship("Example", back_populates='annotations')

  annotator_id: Mapped[int] = mapped_column(sqla.Integer, sqla.ForeignKey('annotators.annotator_id'), nullable=False)
  annotator: Mapped['Annotator'] = relationship("Annotator", back_populates='annotations')

  def __repr__(self)-> str:
    # reprlib keeps the nested annotator repr short.
    return (f'Annotation(annotation_id={self.annotation_id!r}, ' +
            f'label={self.label!r}, ' +
            f'annotator={reprlib.repr(self.annotator)})')


class Example(Base):
  """An item to annotate, identified by `example_id` and pointing to an image."""
  __tablename__ = 'examples'

  example_id: Mapped[str] = mapped_column(sqla.String(255), nullable=False, primary_key=True)
  image: Mapped[str] = mapped_column(sqla.String(255), nullable=False)

  annotations: Mapped[List['Annotation']] = relationship("Annotation", back_populates="example")
  assigned_annotators: Mapped[List['AssignedAnnotator']] = relationship("AssignedAnnotator", back_populates="example")

  def __repr__(self)-> str:
    # reprlib truncates long annotation lists.
    return (f'Example(example_id={self.example_id!r}, ' +
            f'image_filename={self.image!r}, ' +
            f'annotations={reprlib.repr(self.annotations)})')


class AssignedAnnotator(Base):
  """Links an example to an annotator who is asked to label it."""
  __tablename__ = 'assigned_annotators'

  assignment_id: Mapped[int] = mapped_column(sqla.Integer, primary_key=True, autoincrement=True)
  example_id: Mapped[str] = mapped_column(sqla.String(255), sqla.ForeignKey('examples.example_id'), nullable=False)
  annotator_id: Mapped[int] = mapped_column(sqla.Integer, sqla.ForeignKey('annotators.annotator_id'), nullable=False)

  example: Mapped['Example'] = relationship("Example", back_populates="assigned_annotators")
  annotator: Mapped['Annotator'] = relationship("Annotator", back_populates="assigned_annotators")

  def __repr__(self)-> str:
    return (f'AssignedAnnotator(assignment_id={self.assignment_id!r}, ' +
            f'example_id={self.example_id!r}, ' +
            f'annotator_id={self.annotator_id!r})')

In [None]:
#@title Init Task Manager
# Connect as the 'postgres' database user whose password was set above.
username = 'postgres'
db_name = database_name
task_manager = TaskManager(
    username=username,
    password=root_password,
    db_name=db_name,
    instance_connection_name=instance_connection_name,
)
task_manager.setup()

In [None]:
#@title Get Project Configs
#@markdown Enter Title for Project
project_title = "" #@param {type: "string"}
assert project_title, "Please enter a project title"

#@markdown Specify deadline to complete annotation project
completion_deadline = "2024-05-30" #@param {type:"date"}

#@markdown Enter Data Source, i.e. GCP cloud bucket name
bucket_name = "" #@param {type: "string"}
assert bucket_name, "Please enter data source for project"

#@markdown Enter Data Source Prefix
bucket_prefix = "" #@param {type: "string"}

#@markdown Enter the maximum number of annotators to assign to each example
max_annotators_per_example = None # @param {type:"integer"}
# Required: it is stored in a Mapped[int] column and drives task assignment,
# so None or a non-positive value cannot be used.
assert max_annotators_per_example and max_annotators_per_example > 0, \
  "Please enter a positive maximum number of annotators per example"

In [None]:
# Authentication and Permissions
def grant_sql_access(user_email, project_id):
  !gcloud projects add-iam-policy-binding {project_id} \
    --member=user:{user_email} \
    --role="roles/cloudsql.client"

def grant_gcs_access(user_email, bucket_name):
  !gsutil iam ch user:{user_email}:objectViewer {bucket_name}

grant_sql_project_access = partial(grant_sql_access, project_id=project_id)
grant_gcs_bucket_access = partial(grant_gcs_access, bucket_name=bucket_name)

def remove_sql_access(user_email, project_id):
  !gcloud projects remove-iam-policy-binding {project_id} \
    --member=user:{user_email} \
    --role="roles/cloudsql.client"

def remove_gcs_access(user_email, bucket_name):
  !gsutil iam ch -d user:{user_email}:objectViewer {bucket_name}

remove_sql_project_access = partial(remove_sql_access, project_id=project_id)
remove_gcs_bucket_access = partial(remove_gcs_access, bucket_name=bucket_name)

## Label Config

Enter the labels for the classification task as a Python list.

Example:
```python
labels = ['cat',
          'dogs'
]
```
If annotators may skip an example, you may add an additional label such as 'skip'.

Example:
```python
labels = ['cat',
          'dogs',
          'skip'
]
```

In [None]:
labels =  ['']

# The labels are stored as one comma-separated string, so each label must be
# a non-empty string without commas. (This rejects the [''] placeholder.)
assert labels and all(isinstance(label, str) and label.strip() for label in labels), \
  "Please enter labels as a python list of non-empty strings"
assert not any(',' in label for label in labels), "Labels must not contain commas"
print("Labels: ", labels)
comma_separated_labels = ','.join(label.strip() for label in labels)

In [None]:
#@title Set Project Configs
task_manager.set_project_configs(
    project_title=project_title,
    cloud_bucket_name=bucket_name,
    cloud_bucket_prefix=bucket_prefix,
    comma_separated_labels=comma_separated_labels,
    max_annotators_per_example=max_annotators_per_example,
    completion_deadline=completion_deadline,
)

# Display the stored configuration to confirm what was saved.
task_manager.get_project_configs()

# Add Users to Database

In [None]:
#@markdown Please Enter Username and password and click to add

def add_user_to_db(input):
  assert USERNAME.value, "Please enter a username to add"
  assert EMAIL.value, "Please enter a user's email"
  assert PASSWORD.value, "Please enter user's password"

  !gcloud sql users create {USERNAME.value} \
    --instance={instance_name} \
    --password={PASSWORD.value}
  print(f"Adding User: username {USERNAME.value}, email: {EMAIL.value}")
  task_manager.add_annotators([{'username': USERNAME.value, 'email': EMAIL.value}])
  task_manager.grant_annotator_access(USERNAME.value)
  grant_sql_project_access(EMAIL.value)
  grant_gcs_bucket_access(EMAIL.value)

def remove_user_from_db(input):
  assert USERNAME.value, "Please enter a username to add"
  assert EMAIL.value, "Please enter a user's email"

  !gcloud sql users delete {USERNAME.value} \
    --instance={instance_name}
  task_manager.remove_annotator([{'username': USERNAME.value}])
  remove_sql_project_access(EMAIL.value)
  remove_gcs_bucket_access(EMAIL.value)
    
# Input fields read by add_user_to_db / remove_user_from_db.
USERNAME = widgets.Text(
    value=None,
    placeholder='username',
    description='USER:',
    disabled=False
)

PASSWORD = widgets.Password(
    value=None,
    placeholder='password',
    description='PASSWORD:',
    disabled=False
)

EMAIL = widgets.Text(
    value=None,
    placeholder='Enter email',
    description='EMAIL:',
    disabled=False
)

# Buttons have no `value` trait, so none is passed.
add_user = widgets.Button(
    description='Add User',
    disabled=False,
    button_style='success',
    tooltip='Create the database user and register them as an annotator',
    icon='check',
)

remove_user = widgets.Button(
    description='Remove User',
    disabled=False,
    button_style='danger',
    tooltip='Delete the database user and remove their access',
    icon='check',
)

add_user.on_click(add_user_to_db)
remove_user.on_click(remove_user_from_db)

data_box1 = widgets.HBox([USERNAME, PASSWORD, EMAIL])
data_box2 = widgets.HBox([add_user, remove_user])
display(widgets.VBox([data_box1, data_box2]))

In [None]:
#@markdown # View list of users added to SQL Instance
# Captured gcloud output: one user name per line.
users = !gcloud sql users list --instance={instance_name} --format="value(name)"
print("List of Added users: ", users)

In [None]:
# Annotator rows registered in the project database.
task_manager.list_annotators()

# Upload Data

Data to be labeled is registered from a CSV metadata file stored in the GCP bucket. Reading directly from a bucket path without a CSV file is not supported yet. Call ```task_manager.add_tasks()``` with the CSV filename. The file is read from `bucket_prefix` in the same bucket as the data and must contain `example_id` and `image` columns.

In [None]:
# Reads <bucket_prefix>/metadata.csv from the bucket and adds examples not yet in the database.
task_manager.add_tasks(bucket_name, bucket_prefix, csv_filename='metadata.csv')

In [None]:
# Show the first 10 registered examples.
task_manager.list_tasks(limit=10)

# Assign Tasks to Annotators

In [None]:
# Uncomment and run to assign examples to annotators.
# NOTE(review): depending on the assign_tasks implementation, re-running may add
# duplicate assignment rows; confirm before running it more than once.
# task_manager.assign_tasks(max_annotators_per_example=max_annotators_per_example)

# Inspect Task Assignments

In [None]:
# Print every assignment as "Example ID: ..., assigned to Annotator: ...".
task_manager.show_assignments()

# Export Annotations

In [None]:
# Completed annotations as a table. Column names come from the query's labels,
# so they are kept even when there are no annotations yet.
result = task_manager.get_completed_annotations()
df = pd.DataFrame(result.all(), columns=list(result.keys()))
df

In [None]:
# index=False: the DataFrame's positional index carries no information.
df.to_csv('annotations.csv', index=False)

# Delete Data/ Database

In [None]:
# Danger: uncomment to permanently delete the database, then the whole instance.
# !gcloud sql databases delete {database_name} \
# --instance={instance_name}
# !gcloud sql instances delete {instance_name}