Skip to content

segment tree

Changi Cho edited this page Oct 7, 2022 · 13 revisions

세그먼트트리

구간의 값들을 미리 계산해 놓고, 구간을 절반씩 나누어 내려가는 트리 탐색(분할 정복)으로 구간의 연산 결과를 O(log N)에 빠르게 찾아내는 방법

트리 배열의 크기는 원본 배열 크기의 4배로 잡으면 충분함 (트리 높이가 ⌈log₂N⌉이므로 사용되는 노드 인덱스는 2^(⌈log₂N⌉+1) 미만이고, 2^⌈log₂N⌉ < 2N 이므로 이는 4N보다 작다)

참고자료

구현

일반 세그먼트 트리 (index update, range update)

/**
 * Segment tree over a fixed-size array: inclusive range queries plus point
 * and range updates.
 *
 * `operation` (sum) combines two child results. `NULL_VALUE` is returned for
 * segments that do not overlap a query and is folded in with `operation`, so
 * it must be the identity of `operation` (0 for +).
 *
 * Updates add `diff` to every affected leaf with `+=` and then rebuild the
 * parents. A range update therefore touches each leaf of the range
 * individually.
 *
 * Indices are 0-based and ranges are inclusive: [left, right]. The tree is
 * stored 1-indexed in a flat vector (node k has children 2k and 2k + 1);
 * 4 * n slots are always enough for n leaves.
 */
template <class T>
struct SegmentTree {
  const T NULL_VALUE = 0;
  int size;
  std::vector<T> tree;

  // Builds the tree from `array`. The array is only read, so it is taken by
  // const reference: const vectors and temporaries are accepted as well.
  SegmentTree(const std::vector<T> &array) {
    size = static_cast<int>(array.size());

    tree.resize(size * 4);

    // An empty array has no root segment: init(1, 0, -1, ...) would read
    // array[0] and tree[2] out of bounds.
    if (size > 0) {
      init(1, 0, size - 1, array);
    }
  }

  // operation part

  // Combines the results of two adjacent segments.
  T operation(T x, T y) { return x + y; }

  // initialize part

  // Recursively fills tree[node], which covers array[start..end].
  void init(int node, int start, int end, const std::vector<T> &array) {
    if (start == end) {
      tree[node] = array[start];
      return;
    }

    int mid = (start + end) / 2;
    init(node * 2, start, mid, array);
    init(node * 2 + 1, mid + 1, end, array);

    tree[node] = operation(tree[node * 2], tree[node * 2 + 1]);
  }

  // query part

  // Returns operation() folded over array[left..right] (inclusive).
  // An empty tree always yields NULL_VALUE.
  T query(int left, int right) {
    if (size == 0) {
      return NULL_VALUE;
    }
    return query(1, 0, size - 1, left, right);
  }

  // Recursive worker: node covers [start, end]; only the overlap with
  // [left, right] contributes to the result.
  T query(int node, int start, int end, int left, int right) {
    if (right < start || end < left) {
      return NULL_VALUE;
    }
    if (left <= start && end <= right) {
      return tree[node];
    }

    int mid = (start + end) / 2;
    T leftPart = query(node * 2, start, mid, left, right);
    T rightPart = query(node * 2 + 1, mid + 1, end, left, right);
    return operation(leftPart, rightPart);
  }

  // update part

  // Adds diff to array[index]. No-op on an empty tree.
  void update(int index, T diff) {
    if (size == 0) {
      return;
    }
    update(1, 0, size - 1, index, index, diff);
  }

  // Adds diff to every element of array[left..right]. No-op on an empty tree.
  void update(int left, int right, T diff) {
    if (size == 0) {
      return;
    }
    update(1, 0, size - 1, left, right, diff);
  }

  // Recursive worker: walks down to every leaf inside [left, right], adds
  // diff there, and recomputes each visited parent from its children.
  void update(int node, int start, int end, int left, int right, T diff) {
    if (right < start || end < left) {
      return;
    }
    if (start == end) {
      tree[node] += diff;
      return;
    }

    int mid = (start + end) / 2;
    update(node * 2, start, mid, left, right, diff);
    update(node * 2 + 1, mid + 1, end, left, right, diff);

    tree[node] = operation(tree[node * 2], tree[node * 2 + 1]);
  }
};

lazy propagation segment tree

lazy propagation을 이용해 구간(range)에 값을 더하는 update를 O(log N)에 처리하는 메소드(update_lazy) 추가

/**
 * Sum segment tree with lazy propagation.
 *
 * Supported operations:
 *   - inclusive range queries;
 *   - point updates and leaf-by-leaf range updates (`update`);
 *   - O(log n) range-add updates (`update_lazy`).
 *
 * `lazyDiff[node]` holds an addition that has been accepted for every
 * element of node's segment but not yet applied to `tree[node]`. `_lazy`
 * applies it to the node and pushes it down to the children. Every
 * recursive walk calls `_lazy` on each node it visits before reading or
 * rebuilding that node.
 *
 * The lazy bookkeeping multiplies a diff by the segment length, so this
 * version is only correct when `operation` is addition.
 *
 * Indices are 0-based and ranges are inclusive: [left, right]. Node k has
 * children 2k and 2k + 1; 4 * n slots are enough for n leaves.
 */
template <class T>
struct SegmentTree {
  const T NULL_VALUE = 0;
  int size;
  std::vector<T> tree;
  std::vector<T> lazyDiff;

  // Builds the tree from `array`. The array is only read, so it is taken by
  // const reference: const vectors and temporaries are accepted as well.
  SegmentTree(const std::vector<T> &array) {
    size = static_cast<int>(array.size());

    tree.resize(size * 4);
    lazyDiff.resize(size * 4);

    // An empty array has no root segment: init(1, 0, -1, ...) would read
    // array[0] and tree[2] out of bounds.
    if (size > 0) {
      init(1, 0, size - 1, array);
    }
  }

  // operation part

  // Combines the results of two adjacent segments.
  T operation(T x, T y) { return x + y; }

  // initialize part

  // Recursively fills tree[node], which covers array[start..end].
  void init(int node, int start, int end, const std::vector<T> &array) {
    if (start == end) {
      tree[node] = array[start];
      return;
    }

    int mid = (start + end) / 2;
    init(node * 2, start, mid, array);
    init(node * 2 + 1, mid + 1, end, array);

    tree[node] = operation(tree[node * 2], tree[node * 2 + 1]);
  }

  // query part

  // Returns the sum of array[left..right] (inclusive).
  // An empty tree always yields NULL_VALUE.
  T query(int left, int right) {
    if (size == 0) {
      return NULL_VALUE;
    }
    return query(1, 0, size - 1, left, right);
  }

  // Recursive worker: node covers [start, end].
  T query(int node, int start, int end, int left, int right) {
    // Materialise any pending addition before tree[node] is read.
    _lazy(node, start, end);

    if (right < start || end < left) {
      return NULL_VALUE;
    }
    if (left <= start && end <= right) {
      return tree[node];
    }

    int mid = (start + end) / 2;
    T leftPart = query(node * 2, start, mid, left, right);
    T rightPart = query(node * 2 + 1, mid + 1, end, left, right);
    return operation(leftPart, rightPart);
  }

  // update part

  // Adds diff to array[index]. No-op on an empty tree.
  void update(int index, T diff) {
    if (size == 0) {
      return;
    }
    update(1, 0, size - 1, index, index, diff);
  }

  // Adds diff to every element of array[left..right], one leaf at a time.
  // No-op on an empty tree.
  void update(int left, int right, T diff) {
    if (size == 0) {
      return;
    }
    update(1, 0, size - 1, left, right, diff);
  }

  // Recursive worker: walks down to every leaf inside [left, right], adds
  // diff there, and recomputes each visited parent from its children.
  void update(int node, int start, int end, int left, int right, T diff) {
    // Apply pending lazy additions first, as query/update_lazy do.
    // Otherwise a leaf reached here still lacks additions recorded on its
    // ancestors. The parent rebuild below would then sum stale children
    // and silently discard those additions. Running this before the
    // out-of-range check also refreshes the untouched sibling.
    _lazy(node, start, end);

    if (right < start || end < left) {
      return;
    }
    if (start == end) {
      tree[node] += diff;
      return;
    }

    int mid = (start + end) / 2;
    update(node * 2, start, mid, left, right, diff);
    update(node * 2 + 1, mid + 1, end, left, right, diff);

    tree[node] = operation(tree[node * 2], tree[node * 2 + 1]);
  }

  // update lazy part

  // Applies node's pending addition to tree[node] (once per element of
  // [start, end]) and defers it to the children, if any.
  void _lazy(int node, int start, int end) {
    if (lazyDiff[node] == 0) {
      return;
    }

    tree[node] += (end - start + 1) * lazyDiff[node];
    if (start != end) {
      lazyDiff[node * 2] += lazyDiff[node];
      lazyDiff[node * 2 + 1] += lazyDiff[node];
    }
    lazyDiff[node] = 0;
  }

  // Adds diff to every element of array[left..right] in O(log n).
  // No-op on an empty tree.
  void update_lazy(int left, int right, T diff) {
    if (size == 0) {
      return;
    }
    update_lazy(1, 0, size - 1, left, right, diff);
  }

  // Recursive worker: fully covered internal nodes are updated directly and
  // the addition is deferred to their children via lazyDiff.
  void update_lazy(int node, int start, int end, int left, int right, T diff) {
    _lazy(node, start, end);

    if (right < start || end < left) {
      return;
    }
    if (start == end) {
      tree[node] += diff;
      return;
    }
    if (left <= start && end <= right) {
      tree[node] += (end - start + 1) * diff;

      lazyDiff[node * 2] += diff;
      lazyDiff[node * 2 + 1] += diff;
      return;
    }

    int mid = (start + end) / 2;
    update_lazy(node * 2, start, mid, left, right, diff);
    update_lazy(node * 2 + 1, mid + 1, end, left, right, diff);

    tree[node] = operation(tree[node * 2], tree[node * 2 + 1]);
  }
};
Clone this wiki locally