In [1]:
## imports and environment variables
# imports
import json
import os
import urllib3
from google.cloud import storage
import google.auth
import google.auth.transport.requests
import polling2
from abc import ABC, abstractmethod

from dataclasses import asdict, dataclass, field
from dataclasses_json import config, dataclass_json

from typing import List, Optional
from enum import Enum
from urllib.parse import quote, urlparse, urlunparse
from requests import Request, Response, Session

# Terra workspace settings, read from the notebook environment.
# Each lookup raises KeyError if the variable is not set.
ws_name = os.environ["WORKSPACE_NAME"]
ws_project = os.environ["WORKSPACE_NAMESPACE"]
ws_bucket = os.environ["WORKSPACE_BUCKET"]

# Echo the settings, one per line.
print(
    "\n".join(
        [
            f"workspace name = {ws_name}",
            f"workspace project = {ws_project}",
            f"workspace bucket = {ws_bucket}",
        ]
    )
)

workspace name = anvil_cmg_ingest_resources
workspace project = dsp-data-ingest
workspace bucket = gs://fc-9cd4583e-7855-4b5e-ae88-d8971cfd5b46


In [2]:
# Requests in this notebook are sent with verify=False; silence the
# InsecureRequestWarning that urllib3 emits for unverified HTTPS requests.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

class RestClient(ABC):
    """Thin convenience layer over the ``requests`` library.

    The ``_get``/``_post``/``_put``/``_delete`` helpers only *build* unprepared
    ``Request`` objects; ``send`` and ``stream_to_file`` prepare and dispatch
    them through a short-lived ``Session``. Kept as its own base class so that
    extra behaviour can be hooked in later.
    """

    def _request(self, verb, url, **kwargs) -> Request:
        # Single construction point shared by the verb helpers below.
        return Request(verb, url, **kwargs)

    def _get(self, url, **kwargs) -> Request:
        return self._request("get", url, **kwargs)

    def _post(self, url, **kwargs) -> Request:
        return self._request("post", url, **kwargs)

    def _put(self, url, **kwargs) -> Request:
        return self._request("put", url, **kwargs)

    def _delete(self, url, **kwargs) -> Request:
        return self._request("delete", url, **kwargs)

    def send(self, request: Request, **kwargs) -> Response:
        """Prepare ``request``, send it, and return the ``Response``.

        Extra keyword arguments are forwarded to ``Session.send``.

        IMPORTANT: TLS certificate verification is disabled (``verify=False``).
        """
        with Session() as session:
            return session.send(request.prepare(), verify=False, **kwargs)

    def stream_to_file(self, outfile, request: Request, **kwargs) -> int:
        """Send ``request`` and stream the response body into ``outfile``.

        ``outfile`` must accept ``bytes`` via ``write``. Returns the total number
        of bytes written. Raises ``requests.HTTPError`` (via ``raise_for_status``)
        when the response has a 4xx/5xx status.

        IMPORTANT: TLS certificate verification is disabled (``verify=False``).
        """
        written = 0
        with Session() as session, session.send(
            request.prepare(), stream=True, verify=False, **kwargs
        ) as resp:
            resp.raise_for_status()
            # iter_content does not guarantee chunk length, see:
            # https://docs.python-requests.org/en/latest/api/#requests.Response.iter_content
            for piece in resp.iter_content(chunk_size=8192):
                outfile.write(piece)
                written += len(piece)
        return written


class AuthenticatedClient(RestClient):
    """RestClient that runs every outgoing request through ``authenticate``.

    Subclasses implement ``authenticate``. ``send`` applies it, then delegates
    to ``RestClient.send``.
    """

    @abstractmethod
    def authenticate(self, request):
        """Return the request to send after applying credentials to ``request``.

        Called by ``send`` for every request before it is dispatched.
        """

    def send(self, request: Request, **kwargs) -> Response:
        """Authenticate ``request`` and send it.

        Params:
            request: Request = the original, unauthenticated request
            **kwargs: forwarded unchanged to ``RestClient.send``

        Returns: Response
        """
        return super().send(self.authenticate(request), **kwargs)
    


class DataTypeEnum(Enum):
    """Column datatype values used by ``ColumnModel.datatype``."""

    STRING = "string"
    BOOLEAN = "boolean"
    BYTES = "bytes"
    DATE = "date"
    DATETIME = "datetime"
    DIRREF = "dirref"
    FILEREF = "fileref"
    # Backward-compatible alias for the originally misspelled member name.
    # A second Enum name bound to an existing value becomes an alias, so
    # DataTypeEnum.FILERER is DataTypeEnum.FILEREF (same object, same value).
    FILERER = "fileref"
    FLOAT = "float"
    FLOAT64 = "float64"
    INTEGER = "integer"
    INT64 = "int64"
    NUMERIC = "numeric"
    RECORD = "record"
    TEXT = "text"
    TIME = "time"
    TIMESTAMP = "timestamp"


class ParititionModeEnum(Enum):
    """Table partitioning modes, used by ``TableModel.partitionMode``.

    NOTE(review): the class name is misspelled ("Paritition"). It is kept
    as-is because other code refers to it by this name.
    """
    NONE = "none"
    DATE = "date"  # presumably pairs with TableModel.datePartitionOptions — confirm
    INT = "int"  # presumably pairs with TableModel.intPartitionOptions — confirm


class CloudPlatformEnum(Enum):
    """Cloud platform values, used by ``TDRDatasetRequest.cloudPlatform``."""
    GCP = "gcp"
    AZURE = "azure"


class FormatEnum(Enum):
    """Ingest input formats, used by ``TDRIngestRequest.format``.

    NOTE(review): presumably CSV/JSON load from ``TDRIngestRequest.path`` and
    ARRAY loads inline ``TDRIngestRequest.records`` — confirm against TDR docs.
    """
    CSV = "csv"
    JSON = "json"
    ARRAY = "array"


class UpdateEnum(Enum):
    """Ingest update strategies, used by ``TDRIngestRequest.updateStrategy``."""
    APPEND = "append"
    REPLACE = "replace"


@dataclass
class TDRResponse:
    """Job payload returned by TDR job-producing and job-status endpoints.

    Built as ``TDRResponse(**response.json())``, so field names must match the
    JSON keys exactly. An unexpected key raises ``TypeError``.
    """
    id: str  # job id, e.g. the value passed to /api/repository/v1/jobs/{id}
    job_status: str  # compared against "failed" when polling
    status_code: int
    description: Optional[str] = None
    submitted: Optional[str] = None
    completed: Optional[str] = None  # None while the job is still running
    class_name: Optional[str] = None


@dataclass
class TDRErrorResponse:
    """Error payload read from a job's ``/result`` endpoint.

    Built as ``TDRErrorResponse(**response.json())``; unknown keys raise ``TypeError``.
    """
    message: Optional[str] = None
    errorDetail: Optional[List[str]] = None
        
        
@dataclass
class AssetTableModel:
    """A table and the subset of its columns included in an asset."""
    name: str
    columns: List[str]  # presumably column names — confirm


@dataclass
class AssetModel:
    """Asset definition within a dataset schema (``SchemaModel.assets``).

    NOTE(review): ``follow`` presumably lists relationship names to traverse
    from ``rootTable``/``rootColumn`` — confirm against TDR docs.
    """
    name: str
    tables: List[AssetTableModel]
    rootTable: str
    rootColumn: str
    follow: Optional[List[str]] = None


@dataclass
class RelationshipTermModel:
    """One end of a relationship: a column in a table."""
    table: str
    column: str

        
@dataclass_json
@dataclass
class RelationshipModel:
    """A named relationship between two table columns in a dataset schema.

    ``from`` is a Python keyword, so the source side is stored as
    ``relationshipFrom``. It is mapped to the JSON key "from" through
    dataclasses_json ``config(field_name="from")``.

    NOTE: that mapping only applies to dataclasses_json encoding/decoding
    (``to_dict``/``to_json``/``from_dict``). ``dataclasses.asdict`` ignores
    field metadata and emits the key "relationshipFrom".
    """
    name: str
    relationshipFrom: RelationshipTermModel = field(metadata=config(field_name="from"))
    to: RelationshipTermModel


@dataclass
class IntPartitionOptionsModel:
    """Options for integer partitioning (``TableModel.intPartitionOptions``).

    NOTE(review): presumably an integer-range partition over ``column`` from
    ``min`` to ``max`` in steps of ``interval`` — confirm against TDR docs.
    """
    column: str
    min: int
    max: int
    interval: int


@dataclass
class DatePartitionOptionsModel:
    """Options for date partitioning (``TableModel.datePartitionOptions``)."""
    column: Optional[str] = None  # column to partition on; None semantics set by TDR — confirm


@dataclass
class ColumnModel:
    """A column definition within a ``TableModel``."""
    name: str
    datatype: DataTypeEnum
    array_of: Optional[bool] = None  # presumably marks an array-of-datatype column — confirm


@dataclass
class TableModel:
    """A table definition within a dataset ``SchemaModel``."""
    name: str
    columns: List[ColumnModel]
    primaryKey: Optional[List[str]] = None  # presumably column names — confirm
    partitionMode: Optional[ParititionModeEnum] = None
    datePartitionOptions: Optional[DatePartitionOptionsModel] = None
    intPartitionOptions: Optional[IntPartitionOptionsModel] = None
    rowCount: Optional[int] = None


@dataclass
class SchemaModel:
    """Dataset schema: tables plus optional relationships and assets."""
    tables: List[TableModel]
    relationships: Optional[List[RelationshipModel]] = None
    assets: Optional[List[AssetModel]] = None

@dataclass
class Storage:
    """One storage-resource entry of a dataset (``storage`` list)."""
    region: Optional[str] = None
    cloudResource: Optional[str] = None
    cloudPlatform: Optional[str] = None  # plain string here, not CloudPlatformEnum
    
@dataclass 
class BigQueryTable:
    """BigQuery access details for one table (``BigQuery.tables``)."""
    name: Optional[str] = None
    id: Optional[str] = None
    # NOTE(review): snake_case unlike its siblings. dataclasses_json matches
    # field names exactly, so confirm the API key really is "qualified_name".
    qualified_name: Optional[str] = None
    link: Optional[str] = None
    sampleQuery: Optional[str] = None

@dataclass
class BigQuery:
    """BigQuery access details for a dataset (``AccessInformation.bigQuery``)."""
    datasetName: Optional[str] = None
    datasetId: Optional[str] = None
    projectId: Optional[str] = None
    link: Optional[str] = None
    tables: Optional[List[BigQueryTable]] = None

@dataclass
class ParquetTable:
    """Parquet access details for one table (``Parquet.tables``)."""
    name: Optional[str] = None
    url: Optional[str] = None
    sasToken: Optional[str] = None  # presumably an Azure SAS token for ``url`` — confirm

@dataclass
class Parquet:
    """Parquet access details for a dataset (``AccessInformation.parquet``).

    All fields are optional. The correctly spelled ``sasToken`` field is
    appended at the end so existing positional construction is unchanged. The
    original misspelled ``sasTtoken`` is kept for backward compatibility.
    ``__post_init__`` copies whichever one is set into the other.
    """

    datasetName: Optional[str] = None
    datasetId: Optional[str] = None
    storageAccountId: Optional[str] = None
    url: Optional[str] = None
    # Legacy misspelled field, kept so existing positional/keyword callers work.
    sasTtoken: Optional[str] = None
    tables: Optional[List[ParquetTable]] = None
    # Same spelling as ParquetTable.sasToken. Field-name-based decoders (e.g.
    # dataclasses_json) drop a "sasToken" key unless a field has this exact
    # name. Presumably that is the API key — confirm.
    sasToken: Optional[str] = None

    def __post_init__(self):
        # Keep both spellings in sync so either attribute yields the token.
        if self.sasToken is None:
            self.sasToken = self.sasTtoken
        elif self.sasTtoken is None:
            self.sasTtoken = self.sasToken
        
@dataclass
class AccessInformation:
    """Access details for a dataset, either BigQuery- or Parquet-based."""
    bigQuery: Optional[BigQuery] = None
    parquet: Optional[Parquet] = None
        

@dataclass
class Item:
    """Summary of one dataset, as listed in ``TDRDataset.items``.

    NOTE: ``TDRDataset`` is built with ``TDRDataset(**json)``, which does not
    convert nested dicts, so ``items`` entries may be plain dicts, not ``Item``.
    """
    id: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    defaultProfileId: Optional[str] = None
    createdDate: Optional[str] = None
    storage: Optional[List[Storage]] = None
    secureMonitoringEnabled: Optional[bool] = None
    cloudPlatform: Optional[str] = None
    dataProject: Optional[str] = None
    storageAccount: Optional[str] = None
    phsId: Optional[str] = None
    selfHosted: Optional[bool] = None


@dataclass
class RoleMap:
    """Role map attached to a dataset listing (``TDRDataset.roleMap``).

    NOTE(review): ``additional_prop1`` looks like a placeholder copied from an
    OpenAPI example. The real payload is presumably keyed by dataset id —
    confirm before relying on this shape.
    """
    additional_prop1: List[str]

@dataclass_json
@dataclass
class TDRDatasetDetail:
    """Full dataset description, including schema and access information.

    Decoded with dataclasses_json (``TDRDatasetDetail.from_json``), which also
    builds the nested dataclass fields (``SchemaModel``, ``Storage``,
    ``AccessInformation``, ...) and ignores JSON keys without a matching field.
    """
    id: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    defaultProfileId: Optional[str] = None
    dataProject: Optional[str] = None
    defaultSnapshotId: Optional[str] = None
    schema: Optional[SchemaModel] = None
    storage: Optional[List[Storage]] = None
    secureMonitoringEnabled: Optional[bool] = None
    phsId: Optional[str] = None
    accessInformation: Optional[AccessInformation] = None
    selfHosted: Optional[bool] = None
    
        
@dataclass_json
@dataclass
class TDRDataset:
    """Result of a dataset listing/search.

    NOTE: when built as ``TDRDataset(**json)`` (not ``from_dict``), nested
    values are not converted, so ``items``/``roleMap`` hold plain dicts.
    """
    total: int
    filteredTotal: int
    items: List[Item]
    roleMap: RoleMap
        
        
@dataclass
class TDRDatasetSearchRequest:
    """Query parameters for listing datasets; ``None`` fields are meant to be omitted.

    NOTE(review): ``sort``/``direction`` are typed ``int`` here, but the API
    may expect string values (e.g. a column name / "asc"/"desc") — confirm.
    """
    filter: Optional[str] = None
    sort: Optional[int] = None
    direction: Optional[int] = None
        

@dataclass
class TDRDatasetRequest:
    """Request body for creating a dataset.

    Enum-typed fields (``cloudPlatform`` and the schema's enums) must be turned
    into their ``.value`` strings when serialized to JSON.
    """
    name: str
    defaultProfileId: str  # presumably the billing profile id — confirm
    schema: SchemaModel
    description: Optional[str] = None
    region: Optional[str] = None
    cloudPlatform: Optional[CloudPlatformEnum] = None
    enableSecureMonitoring: Optional[bool] = None
    phsId: Optional[str] = None
    experimentalSelfHosted: Optional[bool] = None


@dataclass
class TDRIngestRequest:
    """Request body for ingesting data into one table of a dataset.

    Field names are sent verbatim as JSON keys; the mix of snake_case and
    camelCase presumably mirrors the TDR API — confirm before renaming.
    """
    table: str  # target table name
    format: FormatEnum
    path: Optional[str] = None  # presumably a file URL, used with CSV/JSON — confirm
    records: Optional[List[str]] = None  # inline records; NOTE(review): TDR may expect objects, not strings — confirm
    load_tag: Optional[str] = None
    profile_id: Optional[str] = None
    max_bad_records: Optional[int] = None
    max_failed_file_loads: Optional[int] = None
    ignore_unknown_values: Optional[bool] = None
    csv_field_delimiter: Optional[str] = None
    csv_quote: Optional[str] = None
    csv_skip_leading_rows: Optional[int] = None
    csv_allow_quoted_newlines: Optional[bool] = None
    csv_null_marker: Optional[str] = None
    csv_generate_row_ids: Optional[bool] = None
    resolve_existing_files: Optional[bool] = None
    transactionId: Optional[str] = None
    updateStrategy: Optional[UpdateEnum] = None


# Python attribute names that must be sent under a different JSON key.
# RelationshipModel stores the keyword ``from`` as ``relationshipFrom``. Its
# dataclasses_json ``config(field_name="from")`` is ignored by
# ``dataclasses.asdict``, so the rename has to happen here.
_TDR_KEY_RENAMES = {"relationshipFrom": "from"}


def as_dict(obj):
    """``dict_factory`` for ``dataclasses.asdict`` that builds TDR JSON payloads.

    ``asdict`` calls this once per (nested) dataclass with its list of
    ``(field_name, value)`` pairs. The returned dict:

    * drops pairs whose value is ``None``, so unset optional fields are omitted;
    * replaces ``Enum`` members with their ``.value``;
    * renames keys listed in ``_TDR_KEY_RENAMES`` (e.g. ``relationshipFrom`` ->
      ``from``).

    Params:
        obj: iterable of ``(key, value)`` pairs

    Returns: dict
    """
    return {
        _TDR_KEY_RENAMES.get(key, key): value.value if isinstance(value, Enum) else value
        for key, value in obj
        if value is not None
    }


class TDRClient(AuthenticatedClient):
    """Client for interfacing with the TDR.

    Enables communication via the TDR API:
    https://data.terra.bio/

    A request whose status code is not in the expected set raises a plain
    ``Exception("Request failed: <status>")``. The response body (or the HTTP
    reason phrase, when the body is empty) is printed first.
    """

    def __init__(
        self,
        host=os.environ.get("ZEBRAFISH_TDR_HOST", "data.terra.bio"),
        token=os.environ.get("ZEBRAFISH_TDR_TOKEN"),
        scheme="https",
    ):
        """Store connection settings.

        Params:
            host: TDR hostname; defaults to $ZEBRAFISH_TDR_HOST or "data.terra.bio".
            token: bearer token; defaults to $ZEBRAFISH_TDR_TOKEN. When falsy,
                requests are sent without an Authorization header.
            scheme: URL scheme, "https" by default.

        NOTE: the environment-variable defaults are evaluated once, when this
        class is defined, not on each instantiation.
        """
        if token is None:
            print("TDRClient has no authentication token.")

        self.host = host
        self.token = token
        self.scheme = scheme

    def _build_url(self, path: str) -> str:
        """Build an absolute URL for ``path`` on the configured host.

        Callers pass paths with a leading "/", which produces "host//api/...".
        The double slash is collapsed and the path is percent-encoded with
        ``quote``. A query string inside ``path`` is preserved.
        """
        url = f"{self.scheme}://{self.host}/{path}"
        parts = urlparse(url)
        return urlunparse(parts._replace(path=quote(parts.path.replace("//", "/"))))

    def authenticate(self, request: Request) -> Request:
        """Attach a bearer-token Authorization header when a token is configured."""
        if self.token:
            request.headers.update({"Authorization": f"Bearer {self.token}"})
        return request

    @staticmethod
    def _check_status(response: Response, success_code: List[int]) -> None:
        """Raise unless ``response.status_code`` is one of ``success_code``.

        Prints the response body (or the HTTP reason phrase when the body is
        empty) before raising, so the server's error detail is visible.

        Raises:
            Exception: "Request failed: <status_code>".
        """
        if response.status_code not in success_code:
            # requests.Response has no ``message`` attribute; ``reason`` is
            # the HTTP reason phrase (e.g. "Not Found").
            msg = response.text if response.text else response.reason
            print(f"msg = {msg}")
            raise Exception(f"Request failed: {response.status_code}")

    def handle_response(
        self, response: Response, success_code: List[int]
    ) -> TDRResponse:
        """Check the status against ``success_code`` and parse a job payload.

        Returns: TDRResponse built from the response JSON.
        Raises: Exception if the status code is not in ``success_code``.
        """
        self._check_status(response, success_code)
        return TDRResponse(**response.json())

    def check_job_polling_response(self, response):
        """Polling predicate: True once the job payload has a ``completed`` value.

        Raises (and so stops polling) if the status is neither 200 nor 202.
        """
        return self.handle_response(response, [200, 202]).completed is not None

    def create_dataset(self, req: TDRDatasetRequest) -> TDRResponse:
        """Create a new dataset; returns the job payload for the creation job."""
        url = self._build_url("/api/repository/v1/datasets")
        params = {"json": asdict(req, dict_factory=as_dict)}
        return self.handle_response(self.send(self._post(url, **params)), [200, 202])

    def get_dataset_by_name(self, req: TDRDatasetSearchRequest) -> TDRDataset:
        """List datasets, sending the non-None fields of ``req`` as query parameters.

        Typically ``req.filter`` carries the dataset name to search for.

        Returns: TDRDataset built with ``TDRDataset(**json)`` (nested values
        stay plain dicts).
        Raises: Exception on a non-200 status.
        """
        params = {"params": asdict(req, dict_factory=as_dict)}
        url = self._build_url("/api/repository/v1/datasets")

        response = self.send(self._get(url, **params))
        self._check_status(response, [200])
        return TDRDataset(**response.json())

    def get_dataset_details(self, id: str) -> TDRDatasetDetail:
        """Get the full details of dataset ``id``, including access information.

        Returns: TDRDatasetDetail decoded via dataclasses_json.
        Raises: Exception on a non-200 status.
        """
        url = self._build_url(f"/api/repository/v1/datasets/{id}")
        response = self.send(self._get(url, params={"include": "ACCESS_INFORMATION"}))
        self._check_status(response, [200])
        return TDRDatasetDetail.from_json(json.dumps(response.json()))

    def job_status(self, id: str) -> TDRResponse:
        """Get TDR Job Status (single request, no polling)."""
        url = self._build_url(f"/api/repository/v1/jobs/{id}")

        # 200 = ok, 202 = running
        return self.handle_response(self.send(self._get(url)), [200, 202])

    def job_status_result(self, id: str) -> TDRErrorResponse:
        """Fetch a job's ``/result`` payload and parse it as ``TDRErrorResponse``.

        Used for failed jobs (see ``poll_job_status``). The status code is
        deliberately not checked here. A result payload with keys other than
        ``message``/``errorDetail`` raises ``TypeError``.
        """
        url = self._build_url(f"/api/repository/v1/jobs/{id}/result")
        response = self.send(self._get(url))
        return TDRErrorResponse(**response.json())

    def poll_job_status(self, id: str) -> TDRResponse:
        """Poll a TDR job every 30 seconds, with no timeout, until it completes.

        Returns: the final ``TDRResponse``, or a ``TDRErrorResponse`` from
        ``job_status_result`` when ``job_status`` is "failed".
        Raises: Exception if any status request returns an unexpected code.
        """
        url = self._build_url(f"/api/repository/v1/jobs/{id}")

        # Poll until completed value is not None
        resp = polling2.poll(
            lambda: self.send(self._get(url)),
            check_success=self.check_job_polling_response,
            step=30,
            poll_forever=True,
        )

        final = self.handle_response(resp, [200])
        # A failed job carries its error detail on the /result endpoint.
        if final.job_status == "failed":
            return self.job_status_result(id)
        return final

    def ingest(self, id: str, req: TDRIngestRequest) -> TDRResponse:
        """Start an ingest into dataset ``id``; returns the ingest job payload."""

        url = self._build_url(f"/api/repository/v1/datasets/{id}/ingest")

        params = {"json": asdict(req, dict_factory=as_dict)}

        print(f"params: {params}")
        return self.handle_response(self.send(self._post(url, **params)), [200, 202])