# Checkout kpt package

In [None]:
!kpt pkg get https://github.com/GoogleCloudPlatform/blueprints/tree/main/catalog/gke

# KRM SDK

In [None]:
# YamlPackage, YamlFile, YamlFileObject

import yaml, os, re

class YamlPackage:
  """
  YamlPackage wraps a directory tree on the local disk.
  It provides a way to iterate over YAML files in the package as YamlFile instances.

  Iteration walks pkg_path with os.walk and yields a YamlFile for every file
  whose name ends in ".yaml" or ".yml" and does not start with ".". A leading
  "./" is stripped from the yielded paths. Every call to iter() restarts the
  walk from the top of the tree.
  """

  def __init__(self, pkg_path):
    self.pkg_path = pkg_path
    # Matches a literal "./" prefix. The dot must be escaped: an unescaped "."
    # matches any character and would strip e.g. "a/" from "a/foo.yaml".
    self.prefix_regex = re.compile(r"^\./")
    self.files_iter = None

  def __iter__(self):
    self.walk_iter = os.walk(self.pkg_path)
    # Drop any file iterator left over from a previous, partially consumed
    # iteration (e.g. find_object returning early); otherwise the new walk
    # would first yield the stale remainder of the old one.
    self.files_iter = None
    return self

  # TODO: skip invisible dirs?
  def __next__(self):
    """Return the next YamlFile in the walk; raise StopIteration when done."""
    while True:
      if self.files_iter is not None:
        # Resumes where the previous call left off in the current directory.
        for name in self.files_iter:
          # skip invisible files
          if name.startswith("."):
            continue
          # skip non-yaml files
          if not name.endswith((".yaml", ".yml")):
            continue
          path = os.path.join(self.parent_path, name)
          # strip optional "./" prefix
          path = self.prefix_regex.sub("", path, 1)
          return YamlFile(path)
      # Advance to the next directory; os.walk's StopIteration ends iteration.
      self.parent_path, _, files = next(self.walk_iter)
      self.files_iter = iter(files)

  def find_object(self, obj_ref, skip_invalid=True):
    """Return the first object in the package matching obj_ref, or None.

    Args:
      obj_ref: an ObjectReference or the string produced by its to_str().
      skip_invalid: when true, objects whose reference cannot be built
        (missing fields) are skipped; otherwise InvalidObject is raised.

    Raises:
      InvalidObject: if skip_invalid is false and an object lacks a field
        required by to_ref().
    """
    if isinstance(obj_ref, ObjectReference):
      obj_ref = obj_ref.to_str()
    for yaml_file in self:
      for obj in yaml_file:
        try:
          ref = obj.to_ref()
        except ObjectFieldNotFound as exp:
          if skip_invalid:
            continue
          raise InvalidObject(obj, exp, yaml_file.file_path) from exp
        if ref.to_str() == obj_ref:
          return obj
    return None

class YamlFile:
  """
  YamlFile wraps a YAML file on the local disk.
  It provides a way to iterate over objects in the file as YamlFileObject instances.
  It also provides a way to write to the file with raw yaml or from an object or
  list of objects.

  The read_*/write_* methods use the file handle opened by __enter__, so they
  must be called inside a `with YamlFile(path, mode) as f:` block. Iterating a
  YamlFile opens it itself (with the constructor's mode), parses every
  document eagerly and closes the file before the first object is returned.
  """

  def __init__(self, file_path, mode='r'):
    self.file_path = file_path
    self.mode = mode  # passed to open() by __enter__
    self.obj_iter = None

  def __enter__(self):
    self.file = open(self.file_path, self.mode, encoding="utf-8")
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.file.close()

  def write_obj(self, obj):
    """Write a single YamlFileObject to the open file as one YAML document."""
    # Unwrap the YamlFileObject: safe_dump only represents plain Python data,
    # so dumping the wrapper itself fails.
    self.write_yaml(yaml.safe_dump(obj.obj, default_flow_style=False))

  def write_objs(self, objs):
    """Write YamlFileObjects to the open file as a multi-document YAML stream."""
    # unwrap YamlFileObject
    docs = [obj.obj for obj in objs]
    self.write_yaml(yaml.safe_dump_all(docs, default_flow_style=False))

  def write_yaml(self, objYaml):
    """Write raw YAML text to the open file."""
    self.file.write(objYaml)

  def read_obj(self):
    """Parse the open file as a single YAML document and wrap it (index 0)."""
    return YamlFileObject(yaml.safe_load(self.read_yaml()), self)

  def read_objs(self):
    """Parse every YAML document in the open file, wrapping each with its index."""
    return [YamlFileObject(doc, self, i)
            for i, doc in enumerate(yaml.safe_load_all(self.read_yaml()))]

  def read_yaml(self):
    """Return the full text of the open file."""
    return self.file.read()

  def __iter__(self):
    # Read all objects up front so the file is closed before iteration.
    with self as f:
      self.obj_iter = iter(f.read_objs())
    return self

  def __next__(self):
    return next(self.obj_iter)

# object context manager with field get/set backed by a YamlFile.
class YamlFileObject:
  """
  YamlFileObject wraps a dictionary parsed from an object in a YAML file.
  It provides a way to read, modify, and write to the file using a dictionary.
  It also provides helper methods for manipulating dictionaries with compound
  dictionaries and lists.

  Field paths are lists of segments: string keys select dict entries and
  integer indices select list items. A bare string is treated as a
  one-segment path. A segment that cannot be resolved (missing key,
  out-of-range index, or a None/scalar parent) counts as "not found".

  Used as a context manager, the object re-reads its document from disk on
  enter and, if the body completes without raising, writes every document of
  the file back on exit.
  """

  # Sentinel returned by _child when a path segment does not resolve.
  _MISSING = object()

  def __init__(self, obj, yaml_file, index=0):
    self.obj = obj # dict
    self.yaml_file = yaml_file
    self.index = index  # position of this document within yaml_file
    self.objs = None # []YamlFileObject

  # read on enter
  def __enter__(self):
    with YamlFile(self.yaml_file.file_path, 'r') as f:
      self.objs = f.read_objs()
      self.obj = self.objs[self.index].obj # unwrap YamlFileObject
    return self

  # write on exit
  def __exit__(self, exc_type, exc_val, exc_tb):
    # Don't persist a half-applied edit if the with-body raised; returning
    # False lets the exception propagate.
    if exc_type is not None:
      return False
    self.objs[self.index] = YamlFileObject(self.obj, self.yaml_file, self.index)
    with YamlFile(self.yaml_file.file_path, 'w') as f:
      f.write_objs(self.objs)
    return False

  @staticmethod
  def _as_path(field_path):
    """Normalize field_path so a bare string means a one-segment path."""
    if isinstance(field_path, str):
      return [field_path]
    return field_path

  @classmethod
  def _child(cls, container, field):
    """Return container[field], or cls._MISSING if it does not exist.

    Dicts are looked up by key; lists only by an in-range integer index.
    Anything else (None, scalars) has no children.
    """
    if isinstance(container, dict):
      return container.get(field, cls._MISSING)
    if (isinstance(container, list) and isinstance(field, int)
        and not isinstance(field, bool)):
      if -len(container) <= field < len(container):
        return container[field]
    return cls._MISSING

  def has_field(self, field_path):
    """Return True if every segment of field_path resolves in self.obj."""
    curr = self.obj
    for field in self._as_path(field_path):
      curr = self._child(curr, field)
      if curr is self._MISSING:
        return False
    return True

  def get_field(self, field_path):
    """Return the value at field_path.

    Raises:
      ObjectFieldNotFound: if any segment of field_path does not resolve.
    """
    path = self._as_path(field_path)
    curr = self.obj
    for field in path:
      curr = self._child(curr, field)
      if curr is self._MISSING:
        raise ObjectFieldNotFound(field, path)
    return curr

  def set_field(self, field_path, value):
    """Set the value at field_path; every parent segment must already exist.

    Raises:
      ObjectFieldNotFound: if an intermediate segment does not resolve.
    """
    path = self._as_path(field_path)
    curr = self.obj
    for field in path[:-1]:
      # TODO: add support for populating intermediate lists/maps
      curr = self._child(curr, field)
      if curr is self._MISSING:
        raise ObjectFieldNotFound(field, path)
    curr[path[-1]] = value

  def to_yaml(self):
    """Return self.obj serialized as a YAML document."""
    return yaml.safe_dump(self.obj, default_flow_style=False)

  def to_ref(self):
    """Build an ObjectReference from apiVersion, kind, metadata.name/namespace.

    Raises:
      ObjectFieldNotFound: if metadata, apiVersion, kind or metadata.name is
        missing (including when the document is empty).
    """
    if not self.has_field(["metadata"]):
      raise ObjectFieldNotFound("metadata")
    if self.has_field(["metadata", "namespace"]):
      namespace = self.get_field(["metadata", "namespace"])
    else:
      namespace = None
    return ObjectReference(
          api_version=self.get_field(["apiVersion"]),
          kind=self.get_field(["kind"]),
          name=self.get_field(["metadata", "name"]),
          namespace=namespace)

class ObjectReference:
  """Identify an object by apiVersion, kind, name and optional namespace.

  to_str() renders a slash-separated key. The "namespaces/<namespace>"
  segment is included only when a namespace is set.
  """

  def __init__(self, api_version, kind, name, namespace=None):
    self.api_version = api_version
    self.kind = kind
    self.name = name
    self.namespace = namespace

  def to_str(self):
    """Return the string key used to compare references."""
    if self.namespace is None:
      return f"{self.api_version}/{self.kind}/{self.name}"
    return f"{self.api_version}/{self.kind}/namespaces/{self.namespace}/{self.name}"

class ObjectFieldNotFound(Exception):
  """Raised when a field (optionally within a field path) is absent.

  Attributes:
    field: the path segment that could not be resolved.
    field_path: the full path being resolved, or None.
  """

  def __init__(self, field, field_path=None):
    message = "field not found: {}".format(field)
    if field_path:
      message = "field not found: {} in path {}".format(field, field_path)
    super().__init__(message)
    self.field = field
    self.field_path = field_path

class InvalidObject(Exception):
  """Raised when an object in a package is missing required fields.

  Args:
    obj: the offending object (e.g. a YamlFileObject).
    cause: the underlying error, such as an ObjectFieldNotFound.
    file_path: optional path of the file containing the object.
  """

  def __init__(self, obj, cause, file_path=None):
    # The first parameter must be named `obj`: the message and attribute below
    # refer to it, and a mismatched name made every construction a NameError.
    if file_path:
      super().__init__("invalid object: {} in file {}: {}".format(obj, file_path, cause))
    else:
      super().__init__("invalid object: {}: {}".format(obj, cause))
    self.obj = obj
    self.cause = cause
    self.file_path = file_path


# Print Funcs

In [None]:
# print_pkg_file_paths
import pandas as pd

def print_pkg_file_paths(pkg_path):
  """Display a one-column table listing every YAML file path in the package."""
  rows = [[yaml_file.file_path] for yaml_file in YamlPackage(pkg_path)]
  display(pd.DataFrame(rows, columns=["File Path"]))

print("Package Files:")
print_pkg_file_paths(".")

In [None]:
# print_pkg_obj_refs

import pandas as pd

def print_pkg_obj_refs(pkg_path):
  """Display a table with one row per object: file path plus reference fields."""
  rows = []
  for yaml_file in YamlPackage(pkg_path):
    path = yaml_file.file_path
    for obj in yaml_file:
      ref = obj.to_ref()
      rows.append([path, ref.api_version, ref.kind, ref.namespace, ref.name])
  columns = ["File Path", "API Version", "Kind", "Namespace", "Name"]
  display(pd.DataFrame(rows, columns=columns))

print("Package Object References:")
print_pkg_obj_refs(".")

In [None]:
def print_objs_yaml(objs):
  """Print each object as a YAML document preceded by a '---' separator."""
  for obj in objs:
    print("---")
    print(obj.to_yaml(), end="")

def print_pkg_yaml(pkg_path):
  """Print every object of every YAML file in the package as one YAML stream."""
  # TODO: preserve comments & order
  # TODO: inject file path as annotation
  for yaml_file in YamlPackage(pkg_path):
    print_objs_yaml(yaml_file)

print("Package as YAML:")
print_pkg_yaml(".")

# Example deployment.yaml

In [None]:
# from yaml: write a literal Deployment manifest to deployment.yaml, then
# echo the file back to show what was written.
deployment_yaml = """\
apiVersion: apps/v1
kind: Deployment
metadata:
  labels:
    app: nginx
  name: nginx
  namespace: default
spec:
  replicas: 1
  selector:
    matchLabels:
      app: nginx
  template:
    metadata:
      labels:
        app: nginx
    spec:
      containers:
      - image: nginx:1.7.9
        name: nginx
        ports:
        - containerPort: 80
"""

with open("deployment.yaml", 'w', encoding="utf-8") as f:
  f.write(deployment_yaml)

with open("deployment.yaml", 'r', encoding="utf-8") as f:
  print(f.read())

In [None]:
# from client object
# see https://github.com/kubernetes-client/python/blob/master/examples/notebooks/create_deployment.ipynb
from kubernetes import client

# Build the nginx Deployment from the innermost pieces outward, then convert
# it to plain dicts with the client's serializer and write it as YAML.
_nginx_container = client.V1Container(
  name="nginx",
  image="nginx:1.7.9",
  ports=[client.V1ContainerPort(container_port=80)],
)
_nginx_pod_template = client.V1PodTemplateSpec(
  metadata=client.V1ObjectMeta(labels={"app": "nginx"}),
  spec=client.V1PodSpec(containers=[_nginx_container]),
)
deployment = client.V1Deployment(
  api_version="apps/v1",
  kind="Deployment",
  metadata=client.V1ObjectMeta(
    name="nginx",
    namespace="default",
    labels={"app": "nginx"},
  ),
  spec=client.V1DeploymentSpec(
    replicas=1,
    selector=client.V1LabelSelector(match_labels={"app": "nginx"}),
    template=_nginx_pod_template,
  ),
)

with open("deployment.yaml", 'w', encoding="utf-8") as f:
  obj = client.ApiClient().sanitize_for_serialization(deployment)
  manifest = yaml.safe_dump(obj, default_flow_style=False)
  f.write(manifest)

with open("deployment.yaml", 'r', encoding="utf-8") as f:
  print(f.read())

# set_annotation

In [None]:
# Define: set_annotation

def set_annotation(obj, key, value):
  """Set metadata.annotations[key] = value, creating missing maps first.

  Args:
    obj: a YamlFileObject (anything exposing has_field/set_field).
    key: annotation key.
    value: annotation value.
  """
  if not obj.has_field(["metadata"]):
    # set_field expects a list of path segments, not a bare string.
    obj.set_field(["metadata"], {})
  if not obj.has_field(["metadata", "annotations"]):
    obj.set_field(["metadata", "annotations"], {})
  obj.set_field(["metadata", "annotations", key], value)

In [None]:
# Set an annotation (example-key: example-value) on the nginx Deployment and
# show the object's YAML before and after the edit.

pkg = YamlPackage(".")
deployment_ref = ObjectReference(
    "apps/v1",     # api_version
    "Deployment",  # kind
    "nginx",       # name
    "default",     # namespace
)

print("Before:\n---")
print(pkg.find_object(deployment_ref).to_yaml())

# Entering the found object re-reads its file; leaving the block writes the
# edited object back to disk.
with pkg.find_object(deployment_ref) as obj:
  set_annotation(obj, "example-key", "example-value")

print("After:\n---")
print(pkg.find_object(deployment_ref).to_yaml())

# set_namespace

In [None]:
# Define: is_obj_valid, is_obj_local_config, is_obj_namespace_scoped

def is_obj_valid(obj):
  """Return True if obj has the kind, metadata and metadata.name fields."""
  return (obj.has_field(["kind"])
          and obj.has_field(["metadata"])
          and obj.has_field(["metadata", "name"]))

def is_obj_local_config(obj):
  """Return True if obj is annotated config.kubernetes.io/local-config: "true".

  Annotations live under metadata.annotations (plural). Invalid objects are
  never treated as local config.
  """
  if not is_obj_valid(obj):
    return False
  path = ["metadata", "annotations", "config.kubernetes.io/local-config"]
  return obj.has_field(path) and obj.get_field(path) == "true"

def is_obj_namespace_scoped(obj):
  """Return True if obj is valid and already has metadata.namespace set."""
  if not is_obj_valid(obj):
    return False
  # ignore objects without a namespace
  # TODO: use cluster resource mapping to check scope (requires client & server)
  return obj.has_field(["metadata", "namespace"])

def set_namespace(obj, value):
  """Overwrite metadata.namespace with value on namespace-scoped objects only."""
  if not is_obj_namespace_scoped(obj):
    return
  obj.set_field(["metadata", "namespace"], value)

In [None]:
# Set namespace to "example-namespace" on every namespaced object in the gke
# package, leaving objects marked as local config untouched.

for yaml_file in YamlPackage("gke"):
  for obj in yaml_file:
    if not is_obj_local_config(obj):
      # Entering re-reads the object from disk; leaving writes it back.
      with obj as o:
        set_namespace(o, "example-namespace")

print("\nPackage Objects:")
print_pkg_obj_refs("gke")

# set_label

In [None]:
# Define: set_label

def set_label(obj, key, value):
  """Set metadata.labels[key] = value, creating missing maps first.

  Args:
    obj: a YamlFileObject (anything exposing has_field/set_field).
    key: label key.
    value: label value.
  """
  # Create metadata too (as set_annotation does); set_field cannot populate
  # missing intermediate maps on its own.
  if not obj.has_field(["metadata"]):
    obj.set_field(["metadata"], {})
  if not obj.has_field(["metadata", "labels"]):
    obj.set_field(["metadata", "labels"], {})
  obj.set_field(["metadata", "labels", key], value)

In [None]:
# Set a label (example-label: example-value) on the nginx Deployment, then
# print the rewritten deployment.yaml.

pkg = YamlPackage(".")
deployment_ref = ObjectReference("apps/v1", "Deployment", "nginx", namespace="default")
deployment = pkg.find_object(deployment_ref)

# Leaving the with-block writes the labelled object back to its file.
with deployment as o:
  set_label(o, "example-label", "example-value")

with open("deployment.yaml", 'r', encoding="utf-8") as f:
  print(f.read())

# select_by_label

In [None]:
# Define: has_label, select_by_label

def has_label(obj, key, value=None):
  """Return True if obj has label `key` (and, if value is given, equal to it).

  Always returns a bool; a missing label yields False.
  """
  path = ["metadata", "labels", key]
  if not obj.has_field(path):
    return False
  return value is None or obj.get_field(path) == value

def select_by_label(pkg, key, value=None):
  """Return the objects in pkg (an iterable of YAML files) matching has_label."""
  return [obj
          for yaml_file in pkg
          for obj in yaml_file
          if has_label(obj, key, value)]

In [None]:
# Find every object labelled example-label: example-value in the package and
# print the matches as a YAML stream.

pkg = YamlPackage(".")
objs = select_by_label(pkg, key="example-label", value="example-value")
print_objs_yaml(objs)